"""Data saver that writes pandas DataFrames into tables of a SQLite database."""

import pandas as pd
|
|
import logging
|
|
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine
|
|
from sqlalchemy.pool import NullPool
|
|
|
|
from src.dataprocessor.domain.base_datasaver import BaseDataSaver
|
|
|
|
# Module-level logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class SQLiteDataSaver(BaseDataSaver):
    """Save pandas DataFrames into tables of a single SQLite database file.

    Instances are normally built with :meth:`create`, which prepares the
    database directory and an async SQLAlchemy engine using the aiosqlite
    driver.
    """

    # SQLAlchemy URL the engine connects to; also included in log messages.
    db_url: str
    # Async engine bound to ``db_url``.
    engine: AsyncEngine

    @classmethod
    async def create(cls, db_path: str) -> "SQLiteDataSaver":
        """Create a saver for the SQLite database file at ``db_path``.

        The parent directory of ``db_path`` is created if it does not exist.
        The path is resolved to an absolute path up front, so the saver keeps
        targeting the same file even if the working directory changes later.

        Args:
            db_path: Filesystem path of the SQLite database file. A relative
                path is interpreted against the current working directory at
                the time of this call.

        Returns:
            A new ``SQLiteDataSaver`` bound to that database file.
        """
        db_file = Path(db_path).resolve()

        # Ensure the directory exists; SQLite will not create missing parent
        # directories when it opens the database file.
        db_file.parent.mkdir(parents=True, exist_ok=True)

        # "sqlite+aiosqlite:///" followed by the absolute path. On POSIX the
        # path starts with "/" (four slashes in total); on Windows it starts
        # with a drive letter such as "C:/" (three slashes).
        db_url = f"sqlite+aiosqlite:///{db_file.as_posix()}"

        # NullPool opens a fresh DB-API connection per checkout and closes it
        # on release, avoiding pooled SQLite connections being shared.
        engine = create_async_engine(db_url, poolclass=NullPool, future=True)
        return cls(db_url=db_url, engine=engine)

    async def save_table(
        self, df: pd.DataFrame, table_name: str, overwrite: bool = True
    ) -> None:
        """Write ``df`` into ``table_name``, either replacing or appending.

        Args:
            df: The DataFrame to save. Its index is not written.
            table_name: Name of the target table.
            overwrite: If True (default), drop the existing table and recreate
                it from ``df`` (pandas ``if_exists="replace"``). If False,
                append the rows to the table, creating it if it does not
                exist (pandas ``if_exists="append"``).
        """
        if_exists = "replace" if overwrite else "append"

        def _write(sync_conn) -> None:
            # pandas' to_sql needs a synchronous connection, so it is run via
            # run_sync against the async connection's sync facade.
            df.to_sql(
                name=table_name,
                con=sync_conn,
                if_exists=if_exists,
                index=False,
            )

        # engine.begin() commits when the block exits normally and rolls
        # back if to_sql raises.
        async with self.engine.begin() as conn:
            await conn.run_sync(_write)

        # Release the engine's connection pool. The engine remains usable for
        # later calls; with NullPool no idle connections are held anyway.
        await self.engine.dispose()

        logger.info(
            "Saved %d rows to table '%s' in database '%s'.",
            len(df),
            table_name,
            self.db_url,
        )
|