"""Download job repository — adapter over ``AsyncSession`` (plan §6.1).""" import datetime as dt import uuid from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession from app.domain.entities.download import DownloadJob from app.infrastructure.db.models.download_job import DownloadJobModel from app.infrastructure.db.models.enums import DownloadStatus # Jobs that are not yet finished — used to dedup an in-flight download. _ACTIVE_STATUSES = ( DownloadStatus.QUEUED.value, DownloadStatus.DOWNLOADING.value, DownloadStatus.ENRICHING.value, ) def _to_entity(row: DownloadJobModel) -> DownloadJob: return DownloadJob( id=row.id, source=row.source, source_id=row.source_id, query=row.query, requested_by=row.requested_by, status=row.status, progress=row.progress, error_message=row.error_message, retry_count=row.retry_count, track_id=row.track_id, created_at=row.created_at, updated_at=row.updated_at, ) class SqlAlchemyDownloadJobRepository: def __init__(self, session: AsyncSession) -> None: self._session = session async def add( self, *, source: str, source_id: str | None, query: str | None, requested_by: uuid.UUID | None, ) -> DownloadJob: row = DownloadJobModel( source=source, source_id=source_id, query=query, requested_by=requested_by, status=DownloadStatus.QUEUED.value, progress=0.0, retry_count=0, ) self._session.add(row) await self._session.flush() await self._session.refresh(row) return _to_entity(row) async def get_by_id(self, job_id: uuid.UUID) -> DownloadJob | None: row = await self._session.get(DownloadJobModel, job_id) return _to_entity(row) if row is not None else None async def get_active_for_source(self, source: str, source_id: str) -> DownloadJob | None: row = ( await self._session.execute( select(DownloadJobModel) .where( DownloadJobModel.source == source, DownloadJobModel.source_id == source_id, DownloadJobModel.status.in_(_ACTIVE_STATUSES), ) .order_by(DownloadJobModel.created_at.desc()) .limit(1) ) ).scalar_one_or_none() return _to_entity(row) if row is not None else None async def list( self, *, requested_by: uuid.UUID | None, status: str | None, limit: int, offset: int, ) -> list[DownloadJob]: stmt = select(DownloadJobModel) if requested_by is not None: stmt = stmt.where(DownloadJobModel.requested_by == requested_by) if status is not None: stmt = stmt.where(DownloadJobModel.status == status) stmt = stmt.order_by(DownloadJobModel.created_at.desc()).limit(limit).offset(offset) rows = (await self._session.execute(stmt)).scalars().all() return [_to_entity(r) for r in rows] async def count(self, *, requested_by: uuid.UUID | None, status: str | None) -> int: stmt = select(func.count()).select_from(DownloadJobModel) if requested_by is not None: stmt = stmt.where(DownloadJobModel.requested_by == requested_by) if status is not None: stmt = stmt.where(DownloadJobModel.status == status) return (await self._session.execute(stmt)).scalar_one() async def set_status( self, job_id: uuid.UUID, *, status: str, error_message: str | None = None, track_id: uuid.UUID | None = None, ) -> None: row = await self._session.get(DownloadJobModel, job_id) if row is None: return row.status = status # ``error_message`` is always written: a successful transition clears a # stale reason from an earlier failed attempt. row.error_message = error_message if track_id is not None: row.track_id = track_id if status == DownloadStatus.DONE.value: row.progress = 1.0 await self._session.flush() async def set_progress(self, job_id: uuid.UUID, progress: float) -> None: row = await self._session.get(DownloadJobModel, job_id) if row is None: return row.progress = max(0.0, min(1.0, progress)) await self._session.flush() async def increment_retry(self, job_id: uuid.UUID) -> int: row = await self._session.get(DownloadJobModel, job_id) if row is None: return 0 row.retry_count += 1 await self._session.flush() return row.retry_count async def delete(self, job_id: uuid.UUID) -> None: row = await self._session.get(DownloadJobModel, job_id) if row is not None: await self._session.delete(row) await self._session.flush() async def failure_rate(self, source: str, *, since: dt.datetime) -> float: total, failed = ( await self._session.execute( select( func.count(), func.count().filter(DownloadJobModel.status == DownloadStatus.FAILED.value), ) .select_from(DownloadJobModel) .where( DownloadJobModel.source == source, DownloadJobModel.created_at >= since, ) ) ).one() return (failed / total) if total else 0.0