From db6b5d79855603b0f0dd074fc6c4719d8cb671a1 Mon Sep 17 00:00:00 2001 From: ValueOn AG Date: Mon, 19 May 2025 01:14:51 +0200 Subject: [PATCH] msft auth integriert --- modules/agents/agentEmail.py | 33 ++++- modules/interfaces/gatewayInterface.py | 170 +++++++++++++------------ modules/interfaces/gatewayModel.py | 2 + modules/routes/routeGeneral.py | 53 +++++++- modules/routes/routeMsft.py | 148 ++++++++++++++------- modules/routes/routeUsers.py | 158 +++++++++++++---------- modules/security/auth.py | 51 +++++--- 7 files changed, 397 insertions(+), 218 deletions(-) diff --git a/modules/agents/agentEmail.py b/modules/agents/agentEmail.py index 19b84948..03fb9018 100644 --- a/modules/agents/agentEmail.py +++ b/modules/agents/agentEmail.py @@ -394,7 +394,7 @@ class AgentEmail(AgentBase): logger.error("No mydom interface available") return None, None - # Get token data from database + # Get token data from database using LucyDOMInterface token_data = self.mydom.getMsftToken() if not token_data: logger.info("No Microsoft token found for user") @@ -409,7 +409,36 @@ class AgentEmail(AgentBase): # Get updated token data after refresh token_data = self.mydom.getMsftToken() - return token_data.get("user_info"), token_data.get("access_token") + # Get user info from token data + user_info = token_data.get("user_info") + if not user_info: + # If user_info is not in token_data, try to get it from the token + headers = { + 'Authorization': f'Bearer {token_data.get("access_token", "")}', + 'Content-Type': 'application/json' + } + try: + response = requests.get('https://graph.microsoft.com/v1.0/me', headers=headers) + if response.status_code == 200: + user_data = response.json() + user_info = { + "name": user_data.get("displayName", ""), + "email": user_data.get("userPrincipalName", ""), + "id": user_data.get("id", "") + } + # Update token data with user info + token_data["user_info"] = user_info + self.mydom.saveMsftToken(token_data) + logger.info(f"Retrieved and stored user info for {user_info.get('name', 'Unknown User')}") + else: + logger.warning(f"Failed to get user info: {response.status_code} - {response.text}") + return None, None + except Exception as e: + logger.error(f"Error getting user info: {str(e)}") + return None, None + + logger.info(f"Retrieved user info for {user_info.get('name', 'Unknown User')}") + return user_info, token_data.get("access_token") except Exception as e: logger.error(f"Error getting current user token: {str(e)}") diff --git a/modules/interfaces/gatewayInterface.py b/modules/interfaces/gatewayInterface.py index 7ed117b3..6d259393 100644 --- a/modules/interfaces/gatewayInterface.py +++ b/modules/interfaces/gatewayInterface.py @@ -115,6 +115,7 @@ class GatewayInterface: "disabled": False, "language": "de", "privilege": "sysadmin", + "authenticationAuthority": "local", "hashedPassword": self._getPasswordHash("The 1st Poweron Admin") # Use a secure password in production! } createdUser = self.db.recordCreate("users", adminUser) @@ -280,16 +281,25 @@ class GatewayInterface: def getUserByUsername(self, username: str) -> Optional[Dict[str, Any]]: """Returns a user by username.""" - # Get all users without mandate filter - users = self.db.getRecordset("users") - for user in users: - if user.get("username") == username: - # Log the fields present in the user record - logger.debug(f"Found user {username} with fields: {list(user.keys())}") - # Return a complete copy of the user record with all fields - return {**user} # Use dict unpacking to ensure we get a complete copy with all fields - logger.debug(f"No user found with username {username}") - return None + try: + # Get users table + users = self.db.getRecordset("users") + if not users: + return None + + # Find user by username + for user in users: + if user.get("username") == username: + logger.info(f"Found user with username {username}") + logger.debug(f"User fields: {list(user.keys())}") + return user + + logger.info(f"No user found with username {username}") + return None + + except Exception as e: + logger.error(f"Error getting user by username: {str(e)}") + return None def getUser(self, _userId: str) -> Optional[Dict[str, Any]]: """Returns a user by ID if user has access.""" @@ -311,78 +321,70 @@ class GatewayInterface: return user - def createUser(self, username: str, password: str, email: str = None, - fullName: str = None, language: str = "de", _mandateId: str = None, - disabled: bool = False, privilege: str = "user") -> Dict[str, Any]: - """Creates a new user if current user has permission.""" - # Validate username - if not username or len(username) < 3: - raise ValueError("Benutzername muss mindestens 3 Zeichen lang sein") + def createUser(self, username: str, password: str = None, email: str = None, fullName: str = None, + language: str = "de", _mandateId: int = None, disabled: bool = False, + privilege: str = "user", authenticationAuthority: str = "local") -> Dict[str, Any]: + """Create a new user""" + try: + # Validate username + if not username: + raise ValueError("Username is required") + + # Check if user already exists with the same authentication authority + existingUser = self.getUserByUsername(username) + if existingUser and existingUser.get("authenticationAuthority") == authenticationAuthority: + raise ValueError(f"Username '{username}' already exists with {authenticationAuthority} authentication") - # Validate password - if not password: - raise ValueError("Passwort ist erforderlich") + # Validate password for local authentication + if authenticationAuthority == "local": + if not password: + raise ValueError("Password is required for local authentication") + if len(password) < 8: + raise ValueError("Password must be at least 8 characters long") - # Password requirements - if len(password) < 8: - raise ValueError("Passwort muss mindestens 8 Zeichen lang sein") - if not any(c.isupper() for c in password): - raise ValueError("Passwort muss mindestens einen Grossbuchstaben enthalten") - if not any(c.islower() for c in password): - raise ValueError("Passwort muss mindestens einen Kleinbuchstaben enthalten") - if not any(c.isdigit() for c in password): - raise ValueError("Passwort muss mindestens eine Zahl enthalten") - if not any(c in "!@#$%^&*(),.?\":{}|<>" for c in password): - raise ValueError("Passwort muss mindestens ein Sonderzeichen enthalten") + # Create user data + userData = { + "username": username, + "email": email, + "fullName": fullName, + "language": language, + "_mandateId": _mandateId or self._mandateId, + "disabled": disabled, + "privilege": privilege, + "authenticationAuthority": authenticationAuthority + } - # Validate email if provided - if email: - import re - email_pattern = r'^[^\s@]+@[^\s@]+\.[^\s@]+$' - if not re.match(email_pattern, email): - raise ValueError("Ungültiges E-Mail-Format") - - # Check if the username already exists - existingUser = self.getUserByUsername(username) - if existingUser: - raise ValueError(f"Benutzer '{username}' existiert bereits") + # Add password hash for local authentication + if authenticationAuthority == "local": + userData["hashedPassword"] = self._getPasswordHash(password) - # Use the provided _mandateId or the current context - userMandateId = _mandateId if _mandateId is not None else self._mandateId - - # Check if user has access to the mandate - if userMandateId != self._mandateId and self.currentUser.get("privilege") != "sysadmin": - raise PermissionError(f"Keine Berechtigung, Benutzer in Mandat {userMandateId} zu erstellen") + # Create user record + createdRecord = self.db.recordCreate("users", userData) + if not createdRecord or not createdRecord.get("id"): + raise ValueError("Failed to create user record") - if not self._canModify("users"): - raise PermissionError("Keine Berechtigung, Benutzer zu erstellen") + # Get created user using the returned ID + createdUser = self.db.getRecordset("users", recordFilter={"id": createdRecord["id"]}) + if not createdUser or len(createdUser) == 0: + # Try to get user by username as fallback + createdUser = self.db.getRecordset("users", recordFilter={"username": userData["username"]}) + if not createdUser or len(createdUser) == 0: + raise ValueError("Failed to retrieve created user") - # Check privilege escalation - if (privilege == "sysadmin" or - (privilege == "admin" and self.currentUser.get("privilege") == "user")): - raise PermissionError(f"Keine Berechtigung, Benutzer mit höherem Privileg zu erstellen: {privilege}") - - userData = { - "_mandateId": userMandateId, - "username": username, - "email": email, - "fullName": fullName, - "disabled": disabled, - "language": language, - "privilege": privilege, - "hashedPassword": self._getPasswordHash(password) - } - - createdUser = self.db.recordCreate("users", userData) - - # Clear the users table from cache to ensure fresh data - if "users" in self.db._tablesCache: - del self.db._tablesCache["users"] - - # Return the complete user record - return createdUser + # Clear users table from cache + if hasattr(self.db, '_tablesCache') and "users" in self.db._tablesCache: + del self.db._tablesCache["users"] + + return createdUser[0] + + except ValueError as e: + logger.error(f"Error creating user: {str(e)}") + raise + except Exception as e: + logger.error(f"Unexpected error creating user: {str(e)}") + raise ValueError(f"Failed to create user: {str(e)}") - def authenticateUser(self, username: str, password: str) -> Optional[Dict[str, Any]]: + def authenticateUser(self, username: str, password: str = None) -> Optional[Dict[str, Any]]: """Authenticates a user by username and password.""" # Clear the users table from cache and reload it if "users" in self.db._tablesCache: @@ -394,12 +396,24 @@ class GatewayInterface: if not user: raise ValueError("Benutzer nicht gefunden") - if not self._verifyPassword(password, user.get("hashedPassword", "")): - raise ValueError("Falsches Passwort") - # Check if the user is disabled if user.get("disabled", False): raise ValueError("Benutzer ist deaktiviert") + + # Handle authentication based on authority + auth_authority = user.get("authenticationAuthority", "local") + + if auth_authority == "local": + if not password: + raise ValueError("Passwort ist erforderlich") + if not self._verifyPassword(password, user.get("hashedPassword", "")): + raise ValueError("Falsches Passwort") + elif auth_authority == "microsoft": + # For Microsoft users, we don't verify the password here + # The authentication is handled by the Microsoft OAuth flow + pass + else: + raise ValueError(f"Unbekannte Authentifizierungsmethode: {auth_authority}") # Create a copy without password hash authenticatedUser = {**user} diff --git a/modules/interfaces/gatewayModel.py b/modules/interfaces/gatewayModel.py index 86f371bc..bde39c60 100644 --- a/modules/interfaces/gatewayModel.py +++ b/modules/interfaces/gatewayModel.py @@ -46,6 +46,7 @@ class User(BaseModel): language: str = Field(description="Preferred language of the user") disabled: Optional[bool] = Field(False, description="Indicates whether the user is disabled") privilege: str = Field(description="Permission level") #sysadmin,admin,user + authenticationAuthority: str = Field(default="local", description="Authentication authority (local, microsoft)") label: Label = Field( default=Label(default="User", translations={"en": "User", "fr": "Utilisateur"}), @@ -62,6 +63,7 @@ class User(BaseModel): "language": Label(default="Language", translations={"en": "Language", "fr": "Langue"}), "disabled": Label(default="Disabled", translations={"en": "Disabled", "fr": "Désactivé"}), "privilege": Label(default="Permission level", translations={"en": "Access level", "fr": "Niveau d'accès"}), + "authenticationAuthority": Label(default="Authentication Authority", translations={"en": "Authentication Authority", "fr": "Autorité d'authentification"}) } diff --git a/modules/routes/routeGeneral.py b/modules/routes/routeGeneral.py index ad1562ae..3ce73501 100644 --- a/modules/routes/routeGeneral.py +++ b/modules/routes/routeGeneral.py @@ -30,7 +30,7 @@ router.mount("/static", StaticFiles(directory=str(staticFolder), html=True), nam logger = logging.getLogger(__name__) -@router.get("/favicon.ico") +@router.get("/favicon.ico", tags=["General"]) async def favicon(): return FileResponse(str(staticFolder / "favicon.ico"), media_type="image/x-icon") @@ -83,12 +83,13 @@ async def loginForAccessToken(formData: OAuth2PasswordRequestForm = Depends()): data={ "sub": user["username"], "_mandateId": str(user["_mandateId"]), # Ensure string - "_userId": str(user["id"]) # Ensure string + "_userId": str(user["id"]), # Ensure string + "authenticationAuthority": user.get("authenticationAuthority", "local") # Add auth authority }, expiresDelta=accessTokenExpires ) - logger.info(f"User {user['username']} successfully logged in with context: _mandateId={user['_mandateId']}, _userId={user['id']}") + logger.info(f"User {user['username']} successfully logged in with context: _mandateId={user['_mandateId']}, _userId={user['id']}, auth={user.get('authenticationAuthority', 'local')}") return {"accessToken": accessToken, "tokenType": "bearer"} except ValueError as e: # Handle authentication errors @@ -195,4 +196,50 @@ async def registerUser(userData: Dict[str, Any]): raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to register user" + ) + +@router.get("/api/user/available", response_model=Dict[str, Any], tags=["General"]) +async def checkUsernameAvailability( + username: str, + authenticationAuthority: str = "local" +): + """Check if a username is available for registration""" + try: + # Get root mandate and admin user IDs + adminGateway = getGatewayInterface() + rootMandateId = adminGateway.getInitialId("mandates") + adminUserId = adminGateway.getInitialId("users") + + if not rootMandateId or not adminUserId: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="System is not properly initialized with root mandate and admin user" + ) + + # Create a new gateway interface instance with admin context + adminGateway = getGatewayInterface(rootMandateId, adminUserId) + + # Check if user exists + existingUser = adminGateway.getUserByUsername(username) + + if not existingUser: + return {"available": True} + + # If user exists, check authentication authority + if existingUser.get("authenticationAuthority") == authenticationAuthority: + return { + "available": False, + "message": f"Username already exists with {authenticationAuthority} authentication" + } + else: + return { + "available": True, + "message": f"Username exists but with different authentication authority" + } + + except Exception as e: + logger.error(f"Error checking username availability: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to check username availability: {str(e)}" ) \ No newline at end of file diff --git a/modules/routes/routeMsft.py b/modules/routes/routeMsft.py index 8401dee8..533b59ef 100644 --- a/modules/routes/routeMsft.py +++ b/modules/routes/routeMsft.py @@ -62,6 +62,13 @@ async def save_token_to_file(token_data, currentUser: Dict[str, Any]): logger.error("No LucyDOM interface available for token storage") return False + # Ensure user info is preserved + if "user_info" not in token_data: + # Try to get user info from the token + user_info = get_user_info_from_token(token_data.get("access_token", "")) + if user_info: + token_data["user_info"] = user_info + # Save token to database success = mydom.saveMsftToken(token_data) if success: @@ -217,18 +224,18 @@ async def login(): async def auth_callback(code: str, state: str, request: Request): """Handle Microsoft OAuth callback""" try: - # Create MSAL app instance - app = msal.ConfidentialClientApplication( - client_id=CLIENT_ID, - client_credential=CLIENT_SECRET, - authority=AUTHORITY + # Create a confidential client application + msal_app = msal.ConfidentialClientApplication( + app_config["client_id"], + authority=app_config["authority"], + client_credential=app_config["client_credential"] ) - # Exchange code for token - token_response = app.acquire_token_by_authorization_code( - code=code, - scopes=SCOPES, - redirect_uri=REDIRECT_URI + # Exchange the authorization code for tokens + token_response = msal_app.acquire_token_by_authorization_code( + code, + SCOPES, + redirect_uri=app_config["redirect_uri"] ) if "error" in token_response: @@ -245,7 +252,7 @@ async def auth_callback(code: str, state: str, request: Request):

Authentication Failed

-

Please try again.

+

Could not acquire access token.

@@ -255,8 +262,9 @@ async def auth_callback(code: str, state: str, request: Request): status_code=400 ) - # Get user info from token + # Get user info from the token user_info = get_user_info_from_token(token_response["access_token"]) + if not user_info: logger.error("Failed to get user info from token") return HTMLResponse( @@ -281,7 +289,72 @@ async def auth_callback(code: str, state: str, request: Request): status_code=400 ) - # Add user info to token data + # Get gateway interface for user operations + gateway = getGatewayInterface() + + # Check if user exists + user = gateway.getUserByUsername(user_info["email"]) + + # If user doesn't exist, create a new user in the default mandate + if not user: + try: + # Get the root mandate ID + rootMandateId = gateway.getInitialId("mandates") + if not rootMandateId: + raise ValueError("Root mandate not found") + + # Create new user with Microsoft authentication + user = gateway.createUser( + username=user_info["email"], + email=user_info["email"], + fullName=user_info.get("name", user_info["email"]), + _mandateId=rootMandateId, + authenticationAuthority="microsoft" + ) + logger.info(f"Created new user for Microsoft account: {user_info['email']}") + + # Verify user was created by retrieving it + user = gateway.getUserByUsername(user_info["email"]) + if not user: + raise ValueError("Failed to retrieve created user") + + except Exception as e: + logger.error(f"Failed to create user for Microsoft account: {str(e)}") + return HTMLResponse( + content=""" + + + Registration Failed + + + +

Registration Failed

+

Could not create user account.

+ + + + """, + status_code=400 + ) + + # Create backend token + access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + access_token = createAccessToken( + data={ + "sub": user["username"], + "_mandateId": str(user["_mandateId"]), + "_userId": str(user["id"]), + "authenticationAuthority": "microsoft" + }, + expiresDelta=access_token_expires + ) + + # Add user info to token response token_response["user_info"] = user_info # Store tokens in session storage for the frontend to pick up @@ -308,7 +381,8 @@ async def auth_callback(code: str, state: str, request: Request): window.opener.postMessage({{ type: 'msft_auth_success', user: {json.dumps(user_info)}, - token_data: {json.dumps(token_response)} + token_data: {json.dumps(token_response)}, + access_token: "{access_token}" }}, '*'); }} // Close window after 3 seconds @@ -322,27 +396,10 @@ async def auth_callback(code: str, state: str, request: Request): return response except Exception as e: - logger.error(f"Authentication failed: {str(e)}") - return HTMLResponse( - content=""" - - - Authentication Failed - - - -

Authentication Failed

-

An error occurred during authentication.

- - - - """, - status_code=500 + logger.error(f"Error in auth callback: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Authentication failed: {str(e)}" ) @router.get("/status") @@ -368,8 +425,9 @@ async def auth_status(currentUser: Dict[str, Any] = Depends(getCurrentActiveUser "message": "Not authenticated with Microsoft" }) - # Verify token is still valid - if not verify_token(token_data["access_token"]): + # Verify token is still valid and get user info + user_info = get_user_info_from_token(token_data["access_token"]) + if not user_info: logger.info("Token invalid, attempting refresh") # Try to refresh the token if not await refresh_token(_userId, currentUser): @@ -380,15 +438,13 @@ async def auth_status(currentUser: Dict[str, Any] = Depends(getCurrentActiveUser }) # Reload token data after refresh token_data = await load_token_from_file(currentUser) - - # Get user info from token data - user_info = token_data.get("user_info") - if not user_info: - logger.info("No user info found in token data") - return JSONResponse({ - "authenticated": False, - "message": "No user information available" - }) + # Get user info again after refresh + user_info = get_user_info_from_token(token_data["access_token"]) + if not user_info: + return JSONResponse({ + "authenticated": False, + "message": "Could not get user info after token refresh" + }) logger.info(f"User {user_info.get('name')} is authenticated") return JSONResponse({ diff --git a/modules/routes/routeUsers.py b/modules/routes/routeUsers.py index 468fe867..57d5bf0e 100644 --- a/modules/routes/routeUsers.py +++ b/modules/routes/routeUsers.py @@ -76,8 +76,7 @@ async def registerUser(request: Request): """Register a new user.""" try: # Get request data - data = await request.json() - logger.info(f"Registration request data: {data}") + userData = await request.json() # Get root mandate and admin user IDs adminGateway = getGatewayInterface() @@ -86,91 +85,110 @@ async def registerUser(request: Request): if not rootMandateId or not adminUserId: raise HTTPException( - status_code=500, + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="System is not properly initialized with root mandate and admin user" ) # Create a new gateway interface instance with admin context adminGateway = getGatewayInterface(rootMandateId, adminUserId) - # Check required fields - if not data.get("username") or not data.get("password"): - logger.error("Missing required fields in registration request") - raise HTTPException(status_code=400, detail="Username and password are required") - - # Create user data - userData = { - "username": data["username"], - "password": data["password"], - "email": data.get("email"), - "fullName": data.get("fullName"), - "language": data.get("language", "de"), - "_mandateId": rootMandateId, - "disabled": False, - "privilege": "user" - } + # Set default values if not provided + if "language" not in userData: + userData["language"] = "en" + if "authenticationAuthority" not in userData: + userData["authenticationAuthority"] = "local" + + # Validate authentication authority + if userData["authenticationAuthority"] not in ["local", "microsoft"]: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Invalid authentication authority: {userData['authenticationAuthority']}" + ) + + # Validate password for local authentication + if userData["authenticationAuthority"] == "local": + if "password" not in userData: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Password is required for local authentication" + ) # Create the user - logger.info(f"Attempting to create user with data: {userData}") - createdUser = adminGateway.createUser(**userData) - logger.info(f"User created successfully: {createdUser}") - - # Add a small delay to ensure database consistency - time.sleep(0.5) - - # Verify the user was created and password was stored - if "hashedPassword" not in createdUser: - logger.error("Password not stored in user record") - # Try to delete the user - try: - adminGateway.deleteUser(createdUser["id"]) - logger.info("Successfully deleted user after password storage failure") - except Exception as e: - logger.error(f"Failed to delete user after password storage failure: {str(e)}") - raise HTTPException(status_code=500, detail="Password storage failed") - - logger.info("User verification successful") - - # Test authentication try: - authResult = adminGateway.authenticateUser(userData["username"], userData["password"]) - if not authResult: - logger.error("Authentication test failed after user creation") + createdUser = adminGateway.createUser( + username=userData["username"], + password=userData.get("password"), + email=userData.get("email"), + fullName=userData.get("fullName"), + language=userData["language"], + _mandateId=userData.get("_mandateId", rootMandateId), + authenticationAuthority=userData["authenticationAuthority"] + ) + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e) + ) + + # Verify the user was created + if not createdUser: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to create user" + ) + + # For local authentication, verify password was stored + if userData["authenticationAuthority"] == "local": + if "hashedPassword" not in createdUser: + logger.error("Password not stored in user record") + # Try to delete the user + try: + adminGateway.deleteUser(createdUser["id"]) + logger.info("Successfully deleted user after password storage failure") + except Exception as e: + logger.error(f"Failed to delete user after password storage failure: {str(e)}") + raise HTTPException(status_code=500, detail="Password storage failed") + + logger.info("User verification successful") + + # Test authentication + try: + authResult = adminGateway.authenticateUser(userData["username"], userData["password"]) + if not authResult: + logger.error("Authentication test failed after user creation") + # Try to delete the user + try: + adminGateway.deleteUser(createdUser["id"]) + logger.info("Successfully deleted user after authentication test failure") + except Exception as e: + logger.error(f"Failed to delete user after authentication test failure: {str(e)}") + raise HTTPException(status_code=500, detail="Authentication test failed") + except ValueError as e: + logger.error(f"Authentication test failed: {str(e)}") # Try to delete the user try: adminGateway.deleteUser(createdUser["id"]) logger.info("Successfully deleted user after authentication test failure") except Exception as e: logger.error(f"Failed to delete user after authentication test failure: {str(e)}") - raise HTTPException(status_code=500, detail="Authentication test failed") - except ValueError as e: - logger.error(f"Authentication test failed: {str(e)}") - # Try to delete the user - try: - adminGateway.deleteUser(createdUser["id"]) - logger.info("Successfully deleted user after authentication test failure") - except Exception as e: - logger.error(f"Failed to delete user after authentication test failure: {str(e)}") - raise HTTPException(status_code=500, detail=f"Authentication test failed: {str(e)}") + raise HTTPException(status_code=500, detail=f"Authentication test failed: {str(e)}") + + logger.info("Authentication test successful") + + # Remove sensitive data from response + if "hashedPassword" in createdUser: + del createdUser["hashedPassword"] - logger.info("Authentication test successful") + return createdUser - # Return success response - return { - "message": "User registered successfully", - "userId": createdUser["id"] - } - - except ValueError as e: - logger.error(f"Validation error during registration: {str(e)}") - raise HTTPException(status_code=400, detail=str(e)) - except PermissionError as e: - logger.error(f"Permission error during registration: {str(e)}") - raise HTTPException(status_code=403, detail=str(e)) + except HTTPException: + raise except Exception as e: - logger.error(f"Unexpected error during registration: {str(e)}") - logger.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail="Internal server error") + logger.error(f"Unexpected error during user registration: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Registration failed: {str(e)}" + ) @router.post("/register-with-msal", response_model=Dict[str, Any]) async def registerUserWithMsal(userData: dict = Body(...)): @@ -372,4 +390,4 @@ async def deleteUser( detail=f"Error deleting user with ID {userId}" ) - return None \ No newline at end of file + return None diff --git a/modules/security/auth.py b/modules/security/auth.py index 9d15e4c2..f9c5530b 100644 --- a/modules/security/auth.py +++ b/modules/security/auth.py @@ -106,25 +106,10 @@ async def getCurrentUser(token: str = Depends(oauth2Scheme)) -> Dict[str, Any]: logger.error(f"User context mismatch: token(_mandateId={_mandateId}, _userId={_userId}) vs user(_mandateId={user.get('_mandateId')}, id={user.get('id')})") raise credentialsException - return user - -async def getCurrentActiveUser(currentUser: Dict[str, Any] = Depends(getCurrentUser)) -> Dict[str, Any]: - """ - Ensures that the user is active. + # Add authentication authority to user data + user["authenticationAuthority"] = user.get("authenticationAuthority", "local") - Args: - currentUser: Current user data - - Returns: - User data - - Raises: - HTTPException: If the user is disabled - """ - if currentUser.get("disabled", False): - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User is disabled") - - return currentUser + return user async def getUserContext(currentUser: Dict[str, Any]) -> Tuple[str, str]: """ @@ -170,4 +155,32 @@ def getInitialContext() -> tuple[str, str]: gateway = getGatewayInterface() mandateId = gateway.getInitialId("mandates") userId = gateway.getInitialId("users") - return mandateId, userId \ No newline at end of file + return mandateId, userId + +async def getCurrentActiveUser(currentUser: Dict[str, Any] = Depends(getCurrentUser)) -> Dict[str, Any]: + """ + Gets the current active user and verifies their authentication authority. + + Args: + currentUser: The current user from getCurrentUser + + Returns: + The current user data + + Raises: + HTTPException: If user is disabled or has invalid authentication authority + """ + if currentUser.get("disabled", False): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="User is disabled" + ) + + auth_authority = currentUser.get("authenticationAuthority", "local") + if auth_authority not in ["local", "microsoft"]: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Invalid authentication authority: {auth_authority}" + ) + + return currentUser \ No newline at end of file