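# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Remove redundant function-level imports that already exist in the module header.

Reads ``import_analysis.csv`` from this script's directory, finds imports made
inside functions whose target module is also imported at module level, and
removes the redundant statements from the source files in place.
"""
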
import ast
import csv
import re
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple


# Paths
# resolve() makes the paths absolute (and follows symlinks). Without it a
# relative __file__ (e.g. "python script.py" on Python < 3.9) gives
# Path("script.py").parent == Path("."), whose parent is still Path(".").
SCRIPT_DIR = Path(__file__).resolve().parent
# Root that dotted module names ("gateway.modules.x") are mapped onto.
GATEWAY_ROOT = SCRIPT_DIR.parent
# CSV of import records read by _findRedundantImports().
INPUT_FILE = SCRIPT_DIR / "import_analysis.csv"


def _getContainer(moduleName: str) -> Optional[str]:
    """Extract the container name from a dotted module path.

    A leading ``gateway.`` package prefix is ignored. The container is the
    second path component (``gateway.modules.services.x`` -> ``services``).
    For ``features``, the next component is appended
    (``gateway.modules.features.chat.x`` -> ``features.chat``).

    Args:
        moduleName: Fully dotted module name, e.g. ``gateway.modules.x.y``.

    Returns:
        One of the following:

        * The container name.
        * ``"app"`` for ``gateway.app`` and for paths with fewer than two
          components.
        * ``None`` for test/script modules, which callers skip.
    """
    if moduleName == "gateway.app":
        return "app"

    # Strip only a leading "gateway." - str.replace() would also remove it from
    # inside names such as "modules.apigateway.x".
    prefix = "gateway."
    if moduleName.startswith(prefix):
        moduleName = moduleName[len(prefix):]
    parts = moduleName.split(".")
    if len(parts) < 2:
        return "app"

    container = parts[1]

    # Test and script code does not belong to any container.
    if container in ("tests", "scripts") or container.startswith("script_"):
        return None
    if parts[0] in ("tests", "scripts"):
        return None

    if container == "features" and len(parts) > 2:
        return f"features.{parts[2]}"

    return container


def _findRedundantImports() -> Dict[str, List[Tuple[str, str]]]:
    """
    Find function-level imports whose target module is also imported in the header.

    Reads INPUT_FILE, a CSV with the columns ``module_name``,
    ``imported_module_name`` and ``position``. For each source module, it
    compares the modules imported at header level with those imported inside
    functions.

    * Only absolute imports of the ``modules.`` package are considered.
    * Source modules that ``_getContainer`` maps to ``None`` (tests/scripts)
      are skipped.

    Note: the comparison is by module only. The CSV does not say which names
    are imported, so a hit here does not guarantee that the header imports the
    same names.

    Returns: Dict[source_module] -> List[(target_module, function_name)].
    Target modules are prefixed with ``gateway.``. A pair appears once per
    matching CSV row, so it can repeat.
    """
    headerImports: Dict[str, Set[str]] = defaultdict(set)
    functionImports: Dict[str, List[Tuple[str, str]]] = defaultdict(list)

    # The csv module requires newline="" so quoted fields with newlines parse.
    with open(INPUT_FILE, "r", encoding="utf-8", newline="") as f:
        for row in csv.DictReader(f):
            sourceFull = row["module_name"]
            targetFull = row["imported_module_name"]
            position = row["position"]

            # Values starting with "(relative)" can never start with "modules.",
            # so relative imports are already excluded by this check.
            if not targetFull.startswith("modules."):
                continue

            if _getContainer(sourceFull) is None:
                continue

            targetFull = f"gateway.{targetFull}"

            if position == "header":
                headerImports[sourceFull].add(targetFull)
            else:
                # Position reads "function <name>"; strip only the leading label.
                funcName = position
                if funcName.startswith("function "):
                    funcName = funcName[len("function "):]
                functionImports[sourceFull].append((targetFull, funcName))

    # A function-level import is redundant if its module is in the header set.
    redundant: Dict[str, List[Tuple[str, str]]] = {}
    for source, imports in functionImports.items():
        headerSet = headerImports.get(source, set())
        matches = [(target, funcName) for target, funcName in imports if target in headerSet]
        if matches:
            redundant[source] = matches

    return redundant


def _modulePathToFilePath(moduleName: str) -> Path:
    """Convert a dotted module name to its source file under GATEWAY_ROOT.

    ``gateway.modules.xxx.yyy`` -> ``GATEWAY_ROOT/modules/xxx/yyy.py``. If that
    file does not exist but ``GATEWAY_ROOT/modules/xxx/yyy/__init__.py`` does,
    the package's ``__init__.py`` is returned instead.

    The returned path is not guaranteed to exist.
    """
    # Strip only a leading "gateway." - str.replace() would also remove it from
    # inner components such as "modules.gateway.client".
    prefix = "gateway."
    if moduleName.startswith(prefix):
        moduleName = moduleName[len(prefix):]

    basePath = GATEWAY_ROOT.joinpath(*moduleName.split("."))
    filePath = basePath.with_suffix(".py")
    # A package's code lives in <name>/__init__.py rather than <name>.py.
    packageInit = basePath / "__init__.py"
    if not filePath.exists() and packageInit.exists():
        return packageInit
    return filePath


def _removeImportFromFunction(filePath: Path, targetModule: str, funcName: str) -> bool:
    """
    Remove one import of ``targetModule`` from inside function ``funcName``.

    This is the regex-based variant; main() uses _removeImportsWithAst instead.
    Recognised forms (single-line, indented statements only):

    * ``from <module> import ...``
    * ``import <module> ...``
    * ``from <parent> import ..., <last>, ...`` where ``<module>`` is
      ``<parent>.<last>``

    An import counts as inside the function when the nearest preceding
    ``def <funcName>(`` line is not followed by any other ``def``/``class``
    line before the import. This is a heuristic; indentation-based scope is not
    checked. Only the first qualifying import is removed.

    Args:
        filePath: Python source file to edit in place.
        targetModule: Imported module, with or without the ``gateway.`` prefix.
        funcName: Name of the function containing the import.

    Returns:
        True if an import was removed and the file rewritten, else False.
    """
    if not filePath.exists():
        print(f" File not found: {filePath}")
        return False

    with open(filePath, "r", encoding="utf-8") as f:
        content = f.read()

    # The import we're looking for (without the gateway prefix).
    importModule = targetModule
    if importModule.startswith("gateway."):
        importModule = importModule[len("gateway."):]
    escapedModule = re.escape(importModule)

    # Group 1: newline + indentation (function-level imports are indented);
    # group 2: the import statement itself.
    patterns = [
        # from modules.xxx import yyy
        rf'(\n[ \t]+)(from {escapedModule} import [^\n]+)',
        # import modules.xxx  (but not modules.xxxy or modules.xxx.sub)
        rf'(\n[ \t]+)(import {escapedModule}(?![\w.])[^\n]*)',
    ]
    if "." in importModule:
        parentModule, lastName = importModule.rsplit(".", 1)
        # from modules.xxx import yyy  when the target module is modules.xxx.yyy
        patterns.append(
            rf'(\n[ \t]+)(from {re.escape(parentModule)} import [^\n]*\b{re.escape(lastName)}\b[^\n]*)'
        )

    funcDefPattern = re.compile(rf'(?:^|\n)[ \t]*(?:async[ \t]+)?def[ \t]+{re.escape(funcName)}\s*\(')
    # Any def/class header at any indentation marks the start of another scope.
    otherScopePattern = re.compile(r'\n[ \t]*(?:async[ \t]+)?(?:def|class)[ \t]')

    for pattern in patterns:
        for match in re.finditer(pattern, content):
            importLine = match.group(2)
            # A line ending in "(" or "\" continues on the next lines; deleting
            # just this line would leave broken syntax, so leave it alone.
            if importLine.rstrip().endswith(("(", "\\")):
                continue

            beforeMatch = content[:match.start()]
            funcMatches = list(funcDefPattern.finditer(beforeMatch))
            if not funcMatches:
                continue

            # Text after the target's "def name(" up to the import: any other
            # def/class in it means the import belongs to a different scope.
            betweenText = beforeMatch[funcMatches[-1].end():]
            if otherScopePattern.search(betweenText):
                continue

            # Drop the preceding newline + the import line.
            content = content[:match.start()] + content[match.end():]
            print(f" Removed: {importLine.strip()}")
            with open(filePath, "w", encoding="utf-8") as f:
                f.write(content)
            return True

    return False


def _importBindings(node: ast.stmt) -> Optional[Set[tuple]]:
    """Describe the bindings made by an import statement as hashable tuples.

    Two imports with equal descriptions import the same objects under the same
    local names. Returns None for non-import statements and for
    ``from x import *``, whose bindings cannot be compared.
    """
    if isinstance(node, ast.ImportFrom):
        if any(alias.name == "*" for alias in node.names):
            return None
        return {
            ("from", node.level, node.module, alias.name, alias.asname or alias.name)
            for alias in node.names
        }
    if isinstance(node, ast.Import):
        # Keep the full dotted name: "import a.b" and "import a.c" both bind "a"
        # but load different submodules.
        return {("import", alias.name, alias.asname) for alias in node.names}
    return None


def _ownsWholeLines(lines: List[str], node: ast.stmt) -> bool:
    """Return True if the statement's lines hold nothing else but a trailing comment.

    Prevents deleting unrelated code joined on the same line, e.g.
    ``x = 1; import y``.
    """
    endLine = getattr(node, "end_lineno", None)
    endCol = getattr(node, "end_col_offset", None)
    if endLine is None or endCol is None:
        # End positions exist only on Python 3.8+. Without them the statement's
        # extent is unknown, so treat it as not removable.
        return False
    # ast column offsets are UTF-8 byte offsets, so slice the encoded lines.
    head = lines[node.lineno - 1].encode("utf-8")[:node.col_offset]
    tail = lines[endLine - 1].encode("utf-8")[endCol:].decode("utf-8", "replace").strip()
    return not head.strip() and (not tail or tail.startswith("#"))


def _removeImportsWithAst(filePath: Path, redundantImports: List[Tuple[str, str]]) -> int:
    """
    Remove function-level imports that duplicate a module-level (header) import.

    For every function named in ``redundantImports``, import statements of the
    listed modules are deleted only when all of the following hold:

    * Every name the statement binds is bound identically by an import directly
      at module level, so removing it cannot leave a name undefined.
    * The statement does not share its lines with other code.
    * Removing it does not empty a block; the block's last statement is kept.

    Multi-line (parenthesised or backslash-continued) imports are removed whole.
    The file is rewritten only if something was removed and the result still
    parses. Original line endings are preserved.

    Limitation: a module-level rebinding of an imported name (a later import or
    assignment of the same name) is not detected.

    Args:
        filePath: Python source file to edit in place.
        redundantImports: ``(target_module, function_name)`` pairs. The
            ``gateway.`` prefix of target modules is ignored.

    Returns:
        Number of import statements removed.
    """
    if not filePath.exists():
        print(f" File not found: {filePath}")
        return 0

    # newline="" keeps CRLF/LF endings as they are, so the rewrite changes
    # nothing but the removed lines.
    with open(filePath, "r", encoding="utf-8", newline="") as f:
        lines = f.readlines()
    content = "".join(lines)

    try:
        tree = ast.parse(content)
    except (SyntaxError, ValueError):  # ValueError: null bytes (Python < 3.12)
        print(f" Syntax error in {filePath}")
        return 0

    # Target modules per function name, without the "gateway." prefix.
    importsByFunc: Dict[str, Set[str]] = defaultdict(set)
    for target, funcName in redundantImports:
        if target.startswith("gateway."):
            target = target[len("gateway."):]
        importsByFunc[funcName].add(target)

    # Bindings made by imports directly at module level. Imports nested in
    # if/try blocks (e.g. "if TYPE_CHECKING:") may not run, so they don't count.
    headerBindings: Set[tuple] = set()
    for stmt in tree.body:
        headerBindings |= _importBindings(stmt) or set()

    # Imports of the listed modules inside the listed functions. They are keyed
    # by id() because nested functions make ast.walk() reach a node twice.
    candidates: Dict[int, ast.stmt] = {}
    for node in ast.walk(tree):
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        targetModules = importsByFunc.get(node.name)
        if not targetModules:
            continue
        for child in ast.walk(node):
            if isinstance(child, ast.ImportFrom):
                # level == 0: a relative import is a different module even if
                # its dotted name matches.
                isTarget = child.level == 0 and child.module in targetModules
            elif isinstance(child, ast.Import):
                isTarget = any(alias.name in targetModules for alias in child.names)
            else:
                isTarget = False
            if isTarget:
                candidates[id(child)] = child

    removable: Dict[int, ast.stmt] = {}
    for key, stmt in sorted(candidates.items(), key=lambda item: item[1].lineno):
        bindings = _importBindings(stmt)
        if not bindings or not bindings <= headerBindings:
            # e.g. header "from m import a" vs function "from m import b".
            print(f" Kept line {stmt.lineno}: names not imported in header")
        elif not _ownsWholeLines(lines, stmt):
            print(f" Kept line {stmt.lineno}: shares its line with other code")
        else:
            removable[key] = stmt

    # Removing every statement of a block (e.g. a function body that is only
    # the import) would leave invalid syntax, so keep the block's last one.
    for node in ast.walk(tree):
        for field in ("body", "orelse", "finalbody"):
            stmts = getattr(node, field, None)
            if isinstance(stmts, list) and stmts and all(id(s) in removable for s in stmts):
                kept = removable.pop(id(stmts[-1]))
                print(f" Kept line {kept.lineno}: only statement left in its block")

    if not removable:
        return 0

    removeLines: Set[int] = set()
    for stmt in removable.values():
        removeLines.update(range(stmt.lineno, stmt.end_lineno + 1))

    newContent = "".join(
        line for lineNo, line in enumerate(lines, 1) if lineNo not in removeLines
    )
    # Last line of defence: never write back a file that no longer parses.
    try:
        ast.parse(newContent)
    except SyntaxError:
        print(f" Skipped {filePath}: result would not parse")
        return 0

    for lineNo in sorted(removeLines):
        print(f" Line {lineNo}: {lines[lineNo - 1].strip()}")

    with open(filePath, "w", encoding="utf-8", newline="") as f:
        f.write(newContent)

    return len(removable)


def main():
    """Script entry point: locate redundant function-level imports and strip them.

    Prints how many candidates the CSV analysis found. It then processes each
    source module in sorted order and finally prints the accumulated count
    returned by the removal step.
    """
    print("Finding redundant imports...")
    redundantBySource = _findRedundantImports()

    candidateTotal = sum(map(len, redundantBySource.values()))
    print(f"Found {candidateTotal} redundant imports in {len(redundantBySource)} files\n")

    removedTotal = 0
    for moduleName in sorted(redundantBySource):
        sourcePath = _modulePathToFilePath(moduleName)
        print(f"\n{moduleName}")
        print(f" File: {sourcePath}")
        removedTotal += _removeImportsWithAst(sourcePath, redundantBySource[moduleName])

    print(f"\n\nTotal removed: {removedTotal} imports")


if __name__ == "__main__":
    main()