# gateway/scripts/script_remove_redundant_imports.py
# 2026-01-23 01:10:00 +01:00

# 245 lines
# 7.8 KiB
# Python

# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Remove redundant function-level imports that already exist in the header.
"""
import ast
import csv
import re
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple
# Paths. resolve() makes them absolute and symlink-free regardless of how the
# script is launched (a relative __file__ would otherwise make GATEWAY_ROOT ".").
SCRIPT_DIR = Path(__file__).resolve().parent
# Dotted module names are mapped to files relative to this directory.
GATEWAY_ROOT = SCRIPT_DIR.parent
# CSV produced by the import analysis step; read by _findRedundantImports.
INPUT_FILE = SCRIPT_DIR / "import_analysis.csv"
def _getContainer(moduleName: str) -> Optional[str]:
    """
    Map a dotted module name to the code "container" it belongs to.

    A leading "gateway." package prefix is ignored (only as a prefix, so a
    later "gateway" segment is preserved). Rules, in order:
    - "gateway.app" belongs to "app".
    - Modules whose first segment is "tests" or "scripts" are excluded.
    - Single-segment modules belong to "app".
    - Modules whose second segment is "tests", "scripts" or starts with
      "script_" are excluded.
    - "<x>.features.<name>..." maps to "features.<name>".
    - Otherwise the second segment is the container.

    Args:
        moduleName: Dotted module name, e.g. "gateway.modules.features.chat.x".

    Returns:
        The container name, or None for modules that should be ignored.
    """
    if moduleName == "gateway.app":
        return "app"
    prefix = "gateway."
    relative = moduleName[len(prefix):] if moduleName.startswith(prefix) else moduleName
    parts = relative.split(".")
    # Checked first so that bare "tests"/"scripts" packages are excluded too.
    if parts[0] in ("tests", "scripts"):
        return None
    if len(parts) < 2:
        return "app"
    container = parts[1]
    if container in ("tests", "scripts") or container.startswith("script_"):
        return None
    if container == "features" and len(parts) > 2:
        return f"features.{parts[2]}"
    return container
def _findRedundantImports(inputFile: Optional[Path] = None) -> Dict[str, List[Tuple[str, str]]]:
    """
    Find function-level imports whose target module is also imported in the header.

    Reads the import-analysis CSV (columns "module_name", "imported_module_name",
    "position"). Only imports of "modules.*" targets are considered, and source
    modules that _getContainer excludes (tests/scripts) are skipped.

    NOTE: redundancy is judged per *module*, not per imported name, so every hit
    is only a candidate; the caller must still check the actual bindings.

    Args:
        inputFile: CSV file to read; defaults to INPUT_FILE.

    Returns:
        Dict[source_module] -> List[(target_module, function_name)], where
        target_module carries the "gateway." prefix. Sources without hits are
        omitted.
    """
    csvPath = INPUT_FILE if inputFile is None else inputFile
    headerImports: Dict[str, Set[str]] = defaultdict(set)
    functionImports: Dict[str, List[Tuple[str, str]]] = defaultdict(list)
    functionPrefix = "function "
    # newline="" is how the csv module expects its input files to be opened.
    with open(csvPath, "r", encoding="utf-8", newline="") as f:
        for row in csv.DictReader(f):
            sourceFull = row["module_name"]
            targetFull = row["imported_module_name"]
            position = row["position"]
            # Relative imports ("(relative)...") never start with "modules.",
            # so this single check filters them out as well.
            if not targetFull.startswith("modules."):
                continue
            if _getContainer(sourceFull) is None:
                continue
            targetFull = f"gateway.{targetFull}"
            if position == "header":
                headerImports[sourceFull].add(targetFull)
            else:
                funcName = (
                    position[len(functionPrefix):]
                    if position.startswith(functionPrefix)
                    else position
                )
                functionImports[sourceFull].append((targetFull, funcName))

    redundant: Dict[str, List[Tuple[str, str]]] = {}
    for source, imports in functionImports.items():
        headerSet = headerImports.get(source, set())
        hits = [(target, funcName) for target, funcName in imports if target in headerSet]
        if hits:
            redundant[source] = hits
    return redundant
def _modulePathToFilePath(moduleName: str, root: Optional[Path] = None) -> Path:
    """
    Convert a dotted module name to its source file path.

    "gateway.modules.xxx.yyy" -> <root>/modules/xxx/yyy.py. If that file does
    not exist but <root>/modules/xxx/yyy/__init__.py does, the package's
    __init__.py is returned instead, so package modules are not skipped.

    Args:
        moduleName: Dotted module name; a leading "gateway." is ignored.
        root: Directory the dotted path is resolved against; defaults to GATEWAY_ROOT.

    Returns:
        Path to the module's source file (which may not exist).
    """
    base = GATEWAY_ROOT if root is None else root
    prefix = "gateway."
    # Strip only the leading package prefix; str.replace would also remove a
    # later "gateway." segment.
    relative = moduleName[len(prefix):] if moduleName.startswith(prefix) else moduleName
    modulePath = base.joinpath(*relative.split("."))
    filePath = modulePath.with_suffix(".py")
    packageInit = modulePath / "__init__.py"
    if not filePath.exists() and packageInit.exists():
        return packageInit
    return filePath
def _removeImportFromFunction(filePath: Path, targetModule: str, funcName: str) -> bool:
    """
    Remove one import of *targetModule* from inside function *funcName* (regex-based).

    Text-based variant; main() uses the AST-based remover. An indented import is
    attributed to *funcName* when it follows the last "def funcName(" with no
    top-level def / async def / class in between. Indented definitions (methods,
    nested functions) are not treated as boundaries, so results for methods are
    approximate.

    Args:
        filePath: Python source file to edit in place.
        targetModule: Module name, with or without the "gateway." prefix.
        funcName: Function expected to contain the import.

    Returns:
        True if an import line was removed and the file rewritten, else False.
    """
    if not filePath.exists():
        print(f" File not found: {filePath}")
        return False
    with open(filePath, "r", encoding="utf-8") as f:
        content = f.read()

    # The import we're looking for (without the gateway prefix).
    prefix = "gateway."
    importModule = targetModule[len(prefix):] if targetModule.startswith(prefix) else targetModule
    escapedModule = re.escape(importModule)
    # Group 1 is the preceding newline + indentation (so column-0 header imports
    # never match); group 2 is the import statement itself.
    patterns = [
        # from modules.xxx import yyy
        rf"(\n[ \t]+)(from {escapedModule} import [^\n]+)",
        # import modules.xxx  (lookahead rejects modules.xxxyyy and modules.xxx.sub)
        rf"(\n[ \t]+)(import {escapedModule}(?![\w.])[^\n]*)",
    ]
    parentModule, _, leafName = importModule.rpartition(".")
    if parentModule:
        # from modules.xxx import yyy  when targetModule is modules.xxx.yyy;
        # \b keeps e.g. "db" from matching inside "feedback".
        patterns.append(
            rf"(\n[ \t]+)(from {re.escape(parentModule)} import [^\n]*\b{re.escape(leafName)}\b[^\n]*)"
        )

    targetDef = re.compile(rf"def {re.escape(funcName)}\s*\(")
    topLevelDefinition = re.compile(r"\n(?:(?:async[ \t]+)?def|class)[ \t]+[A-Za-z_]")
    for pattern in patterns:
        for match in re.finditer(pattern, content):
            beforeMatch = content[:match.start()]
            defMatches = list(targetDef.finditer(beforeMatch))
            if not defMatches:
                continue
            # The slice starts at "def funcName(" itself (no preceding "\n"), so
            # the target is never matched here: ANY hit is a different definition
            # and the import does not belong to funcName.
            betweenText = beforeMatch[defMatches[-1].start():]
            if topLevelDefinition.search(betweenText):
                continue
            content = content[:match.start()] + content[match.end():]
            print(f" Removed: {match.group(2).strip()}")
            with open(filePath, "w", encoding="utf-8") as f:
                f.write(content)
            return True
    return False
def _importBindings(node: ast.stmt) -> Set[Tuple[str, str, Optional[str]]]:
    """Return the (source module, imported name, alias) triples bound by an import statement."""
    if isinstance(node, ast.ImportFrom):
        # Keep relative-import dots so "from .x import y" never equals "from x import y".
        source = "." * node.level + (node.module or "")
        return {(source, alias.name, alias.asname) for alias in node.names}
    # Plain "import a.b [as c]" has no "from" part; "" cannot clash with ImportFrom.
    return {("", alias.name, alias.asname) for alias in node.names}


def _occupiesWholeLines(node: ast.stmt, lines: List[str]) -> bool:
    """Return True if no other code shares the first or last physical line of *node*."""
    # ast column offsets count UTF-8 bytes, so slice the encoded lines.
    firstLine = lines[node.lineno - 1].encode("utf-8")
    lastLine = lines[node.end_lineno - 1].encode("utf-8")
    leading = firstLine[:node.col_offset].strip()
    trailing = lastLine[node.end_col_offset:].strip()
    return not leading and (not trailing or trailing.startswith(b"#"))


def _removeImportsWithAst(filePath: Path, redundantImports: List[Tuple[str, str]]) -> int:
    """
    Use the AST to identify and remove redundant function-level imports.

    An import inside one of the listed functions (nested code included) is
    removed only when all of the following hold:
    - it is an absolute import of one of that function's listed target modules;
    - every (module, name, alias) it binds is also bound by a top-level import
      of the same file, so the module-level name is an exact stand-in;
    - it does not share a physical line with other code.
    Multi-line (parenthesised) imports are removed entirely. A block left empty
    by the removal gets a "pass" so the file still parses.

    Args:
        filePath: Python source file, rewritten in place when anything is removed.
        redundantImports: (target_module, function_name) pairs; a leading
            "gateway." on target_module is ignored.

    Returns:
        Number of import statements removed (0 if the file is missing or does
        not parse).
    """
    if not filePath.exists():
        return 0
    with open(filePath, "r", encoding="utf-8") as f:
        lines = f.readlines()
    try:
        tree = ast.parse("".join(lines))
    except SyntaxError:
        print(f" Syntax error in {filePath}")
        return 0

    # Group target modules by function name.
    prefix = "gateway."
    importsByFunc: Dict[str, Set[str]] = defaultdict(set)
    for target, funcName in redundantImports:
        module = target[len(prefix):] if target.startswith(prefix) else target
        importsByFunc[funcName].add(module)

    # Names the module header already binds; only exact duplicates are removable.
    headerBindings: Set[Tuple[str, str, Optional[str]]] = set()
    for stmt in tree.body:
        if isinstance(stmt, (ast.Import, ast.ImportFrom)):
            headerBindings |= _importBindings(stmt)

    # id(node) -> node; keying by id also dedupes imports reached twice via
    # nested functions that share a listed name.
    removable: Dict[int, ast.stmt] = {}
    for node in ast.walk(tree):
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        targetModules = importsByFunc.get(node.name)
        if not targetModules:
            continue
        for child in ast.walk(node):
            if isinstance(child, ast.ImportFrom):
                matches = child.level == 0 and child.module in targetModules
            elif isinstance(child, ast.Import):
                matches = any(alias.name in targetModules for alias in child.names)
            else:
                continue
            if (
                matches
                and _importBindings(child) <= headerBindings
                and _occupiesWholeLines(child, lines)
            ):
                removable[id(child)] = child
    if not removable:
        return 0

    # A statement list whose every entry is removed would become an empty
    # (invalid) block; replace its first line with "pass" at the same indent.
    passLines: Dict[int, str] = {}
    for node in ast.walk(tree):
        for field in ("body", "orelse", "finalbody"):
            stmts = getattr(node, field, None)
            if isinstance(stmts, list) and stmts and all(id(s) in removable for s in stmts):
                firstLine = lines[stmts[0].lineno - 1]
                indent = firstLine[: len(firstLine) - len(firstLine.lstrip())]
                passLines[stmts[0].lineno] = f"{indent}pass\n"

    dropLines: Set[int] = set()
    for stmt in removable.values():
        dropLines.update(range(stmt.lineno, stmt.end_lineno + 1))

    newLines = []
    for lineNo, line in enumerate(lines, 1):
        if lineNo in passLines:
            print(f" Line {lineNo}: {line.strip()} -> pass")
            newLines.append(passLines[lineNo])
        elif lineNo in dropLines:
            print(f" Line {lineNo}: {line.strip()}")
        else:
            newLines.append(line)
    with open(filePath, "w", encoding="utf-8") as f:
        f.writelines(newLines)
    return len(removable)
def main():
    """Script entry point: collect redundant function-level imports and strip them file by file."""
    print("Finding redundant imports...")
    redundantBySource = _findRedundantImports()
    candidateTotal = sum(map(len, redundantBySource.values()))
    print(f"Found {candidateTotal} redundant imports in {len(redundantBySource)} files\n")

    removedTotal = 0
    # Process source modules in a stable, alphabetical order.
    for moduleName in sorted(redundantBySource):
        targetPath = _modulePathToFilePath(moduleName)
        print(f"\n{moduleName}")
        print(f" File: {targetPath}")
        removedTotal += _removeImportsWithAst(targetPath, redundantBySource[moduleName])

    print(f"\n\nTotal removed: {removedTotal} imports")


if __name__ == "__main__":
    main()