# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Remove redundant function-level imports that already exist in the header.
"""

import csv
import re
import ast
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple
from collections import defaultdict

# Paths
SCRIPT_DIR = Path(__file__).parent
GATEWAY_ROOT = SCRIPT_DIR.parent
INPUT_FILE = SCRIPT_DIR / "import_analysis.csv"

# Prefix carried by the fully-qualified module names used in this script.
_GATEWAY_PREFIX = "gateway."


def _stripGatewayPrefix(moduleName: str) -> str:
    """Return *moduleName* without its leading ``gateway.`` prefix.

    Only the prefix is removed. A ``gateway.`` segment elsewhere in the dotted
    path (e.g. ``gateway.modules.gateway.x``) is preserved.
    """
    if moduleName.startswith(_GATEWAY_PREFIX):
        return moduleName[len(_GATEWAY_PREFIX):]
    return moduleName


def _getContainer(moduleName: str) -> Optional[str]:
    """Extract the container name from a dotted module path.

    Returns:
        ``"app"`` for ``gateway.app`` or paths with fewer than two segments;
        ``None`` for test/script modules, which the caller skips;
        ``"features.<name>"`` for feature sub-packages;
        otherwise the second path segment (after stripping ``gateway.``).
    """
    if moduleName == "gateway.app":
        return "app"

    parts = _stripGatewayPrefix(moduleName).split(".")
    if len(parts) < 2:
        return "app"

    container = parts[1]
    # Test and script modules are excluded from the analysis.
    if container in ("tests", "scripts") or container.startswith("script_"):
        return None
    if parts[0] in ("tests", "scripts"):
        return None

    if container == "features" and len(parts) > 2:
        return f"features.{parts[2]}"
    return container


def _findRedundantImports() -> Dict[str, List[Tuple[str, str]]]:
    """
    Find function imports whose target module is also imported in the header.

    Reads INPUT_FILE (CSV columns ``module_name``, ``imported_module_name`` and
    ``position``, where ``position`` is ``"header"`` or ``"function <name>"``).
    Only targets starting with ``modules.`` are considered; anything else,
    including targets reported as ``(relative)...``, is ignored. Source modules
    for which _getContainer returns None (tests/scripts) are skipped.

    A match means only that the same *module* appears in both places.
    _removeImportsWithAst additionally checks that the header binds the same
    names before deleting anything.

    Returns:
        Dict[source_module] -> List[(target_module, function_name)], where
        target_module carries the ``gateway.`` prefix.
    """
    headerImports: Dict[str, Set[str]] = defaultdict(set)
    functionImports: Dict[str, List[Tuple[str, str]]] = defaultdict(list)

    # newline="" is what the csv module expects for file objects.
    with open(INPUT_FILE, "r", encoding="utf-8", newline="") as f:
        for row in csv.DictReader(f):
            sourceFull = row["module_name"]
            targetFull = row["imported_module_name"]
            position = row["position"]

            if not targetFull.startswith("modules."):
                continue
            if _getContainer(sourceFull) is None:
                continue

            targetFull = f"gateway.{targetFull}"
            if position == "header":
                headerImports[sourceFull].add(targetFull)
            else:
                funcName = position.replace("function ", "")
                functionImports[sourceFull].append((targetFull, funcName))

    redundant: Dict[str, List[Tuple[str, str]]] = defaultdict(list)
    for source, imports in functionImports.items():
        headerSet = headerImports.get(source, set())
        for target, funcName in imports:
            if target in headerSet:
                redundant[source].append((target, funcName))

    return dict(redundant)


def _modulePathToFilePath(moduleName: str) -> Path:
    """Convert a dotted module name to its ``.py`` file under GATEWAY_ROOT.

    Example: ``gateway.modules.xxx.yyy`` -> ``GATEWAY_ROOT/modules/xxx/yyy.py``.
    """
    parts = _stripGatewayPrefix(moduleName).split(".")
    return GATEWAY_ROOT.joinpath(*parts).with_suffix(".py")


def _removeImportFromFunction(filePath: Path, targetModule: str, funcName: str) -> bool:
    """
    Remove one import of *targetModule* from inside function *funcName*.

    This is the text/regex based legacy approach; main() uses
    _removeImportsWithAst instead. At most one import line is removed per call.

    Returns True if an import was removed and the file was rewritten.
    """
    if not filePath.exists():
        print(f" File not found: {filePath}")
        return False

    with open(filePath, "r", encoding="utf-8") as f:
        content = f.read()

    # The import as written in source (without the gateway prefix).
    importModule = _stripGatewayPrefix(targetModule)
    parentModule = importModule.rsplit(".", 1)[0]
    leafName = importModule.rsplit(".", 1)[-1]
    escapedModule = re.escape(importModule)

    # The (?![\w.]) guard and the \b anchors stop "modules.a" from also
    # matching "modules.ab" / names that merely contain the leaf.
    patterns = [
        # from modules.xxx import yyy
        rf'(\n[ \t]+)(from {escapedModule} import [^\n]+)',
        # import modules.xxx
        rf'(\n[ \t]+)(import {escapedModule}(?![\w.])[^\n]*)',
        # from modules.xxx import yyy  (leaf imported from its parent package)
        rf'(\n[ \t]+)(from {re.escape(parentModule)} import [^\n]*\b{re.escape(leafName)}\b[^\n]*)',
    ]
    targetDef = re.compile(rf'\bdef {re.escape(funcName)}\s*\(')
    # Any function definition (top-level or indented, sync or async) that
    # starts on its own line.
    anyDef = re.compile(r'\n[ \t]*(?:async[ \t]+)?def [A-Za-z_]\w*\s*\(')

    for pattern in patterns:
        for match in re.finditer(pattern, content):
            beforeMatch = content[:match.start()]
            defMatches = list(targetDef.finditer(beforeMatch))
            if not defMatches:
                continue
            # betweenText begins at the target's own "def" (no preceding
            # newline), so any anyDef hit is a *different* function sitting
            # between the target and this import.
            betweenText = beforeMatch[defMatches[-1].start():]
            if anyDef.search(betweenText):
                continue

            # Group 1 holds the preceding newline + indent, so slicing the
            # match out drops the whole import line.
            content = content[:match.start()] + content[match.end():]
            print(f" Removed: {match.group(2).strip()}")
            with open(filePath, "w", encoding="utf-8") as f:
                f.write(content)
            return True

    return False


def _importBindings(node: ast.AST) -> Set[tuple]:
    """Return a hashable description of every name an import statement binds.

    Two imports with equal descriptions bind the same names to the same
    objects. Non-import nodes yield an empty set.
    """
    if isinstance(node, ast.ImportFrom):
        return {("from", node.module, node.level, a.name, a.asname) for a in node.names}
    if isinstance(node, ast.Import):
        return {("import", a.name, a.asname) for a in node.names}
    return set()


def _endLineno(node: ast.stmt) -> int:
    """Last source line of *node* (falls back to its first line if unknown)."""
    return getattr(node, "end_lineno", None) or node.lineno


def _isStandaloneStatement(lines: List[str], node: ast.stmt) -> bool:
    """Return True if *node* has its source lines to itself.

    This guards against deleting code joined to the import with ``;``. AST
    column offsets are UTF-8 byte offsets, so lines are encoded before slicing.
    """
    firstLine = lines[node.lineno - 1].encode("utf-8")
    if firstLine[:node.col_offset].strip():
        return False
    endCol = getattr(node, "end_col_offset", None)
    if endCol is None:
        return True
    tail = lines[_endLineno(node) - 1].encode("utf-8")[endCol:].strip()
    return not tail or tail.startswith(b"#")


def _statementsNeedingPass(tree: ast.AST, removeIds: Set[int]) -> Set[int]:
    """Return ids of removable statements that must become ``pass``.

    If every statement of some block (function body, if/else branch, try
    body, ...) is scheduled for removal, the block would be left empty and the
    file would no longer parse. The first statement of such a block is
    replaced with ``pass`` instead of being deleted.
    """
    needPass: Set[int] = set()
    for node in ast.walk(tree):
        for _, value in ast.iter_fields(node):
            if isinstance(value, list) and value and all(id(item) in removeIds for item in value):
                needPass.add(id(value[0]))
    return needPass


def _passLine(lines: List[str], node: ast.stmt) -> str:
    """Build a ``pass`` line at *node*'s indentation, keeping its line ending."""
    indent = lines[node.lineno - 1].encode("utf-8")[:node.col_offset].decode("utf-8")
    lastLine = lines[_endLineno(node) - 1]
    eol = lastLine[len(lastLine.rstrip("\r\n")):]
    return f"{indent}pass{eol}"


def _removeImportsWithAst(filePath: Path, redundantImports: List[Tuple[str, str]]) -> int:
    """
    Use the AST to identify and remove redundant imports.

    An import inside a function named in *redundantImports* is removed only if:
      - it imports one of the target modules listed for that function,
      - every name it binds is bound identically by a top-level import, and
      - it is the only statement on its source lines.

    Multi-line (parenthesised / continued) imports are removed in full. An
    import that is the sole statement of a block is replaced with ``pass``.
    The file is rewritten, keeping its original line endings, only if the
    result still parses.

    Returns the number of import statements removed.
    """
    if not filePath.exists():
        return 0

    # newline="" keeps "\r\n" etc. untranslated so the rewrite preserves them.
    with open(filePath, "r", encoding="utf-8", newline="") as f:
        lines = f.readlines()
    content = "".join(lines)

    try:
        tree = ast.parse(content)
    except SyntaxError:
        print(f" Syntax error in {filePath}")
        return 0

    # Group target modules (as written in source) by function name.
    importsByFunc: Dict[str, Set[str]] = defaultdict(set)
    for target, funcName in redundantImports:
        importsByFunc[funcName].add(_stripGatewayPrefix(target))

    # Names bound by imports that sit directly at module level.
    headerBindings: Set[tuple] = set()
    for stmt in tree.body:
        headerBindings |= _importBindings(stmt)

    removable: Dict[int, ast.stmt] = {}
    seen: Set[int] = set()
    for node in ast.walk(tree):
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        targetModules = importsByFunc.get(node.name)
        if not targetModules:
            continue

        # ast.walk also descends into nested functions/classes.
        for child in ast.walk(node):
            if isinstance(child, ast.ImportFrom):
                isTarget = child.module in targetModules
            elif isinstance(child, ast.Import):
                isTarget = any(alias.name in targetModules for alias in child.names)
            else:
                continue
            if not isTarget or id(child) in seen:
                continue
            seen.add(id(child))

            if not _importBindings(child).issubset(headerBindings):
                # e.g. header "from m import a" vs function "from m import b":
                # removing the function import would leave "b" undefined.
                print(f" Skipped line {child.lineno}: header does not bind the same names")
                continue
            if not _isStandaloneStatement(lines, child):
                print(f" Skipped line {child.lineno}: shares its line with other code")
                continue
            removable[id(child)] = child

    if not removable:
        return 0

    passIds = _statementsNeedingPass(tree, set(removable))
    dropped: Set[int] = set()
    replacements: Dict[int, str] = {}
    for child in removable.values():
        dropped.update(range(child.lineno, _endLineno(child) + 1))
        if id(child) in passIds:
            replacements[child.lineno] = _passLine(lines, child)

    newLines: List[str] = []
    for lineNo, line in enumerate(lines, 1):
        if lineNo not in dropped:
            newLines.append(line)
            continue
        print(f" Line {lineNo}: {line.strip()}")
        if lineNo in replacements:
            newLines.append(replacements[lineNo])

    newContent = "".join(newLines)
    # Safety net: never write a file we have broken.
    try:
        ast.parse(newContent)
    except SyntaxError as exc:
        print(f" Not writing {filePath}: result would not parse ({exc})")
        return 0

    with open(filePath, "w", encoding="utf-8", newline="") as f:
        f.write(newContent)

    return len(removable)


def main():
    """Find redundant imports listed in INPUT_FILE and remove them per file."""
    if not INPUT_FILE.exists():
        print(f"Input file not found: {INPUT_FILE}")
        return

    print("Finding redundant imports...")
    redundant = _findRedundantImports()

    totalCount = sum(len(v) for v in redundant.values())
    print(f"Found {totalCount} redundant imports in {len(redundant)} files\n")

    removedCount = 0
    for source, imports in sorted(redundant.items()):
        filePath = _modulePathToFilePath(source)
        print(f"\n{source}")
        print(f" File: {filePath}")
        removedCount += _removeImportsWithAst(filePath, imports)

    print(f"\n\nTotal removed: {removedCount} imports")


if __name__ == "__main__":
    main()