"""S3FileStorage — stores files in any S3-compatible object store."""

import tempfile
from collections.abc import AsyncGenerator, AsyncIterator
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from pathlib import Path
from typing import Any

import aioboto3
import anyio
from botocore.exceptions import ClientError

from app.domain.entities.storage import ObjectStat
from app.domain.errors import StorageError

# Fallback MIME types keyed by lower-case file extension (without the dot).
_EXT_CONTENT_TYPE: dict[str, str] = {
    "mp3": "audio/mpeg",
    "flac": "audio/flac",
    "m4a": "audio/mp4",
    "aac": "audio/aac",
    "ogg": "audio/ogg",
    "opus": "audio/ogg",
    "wav": "audio/wav",
    "wma": "audio/x-ms-wma",
    "aiff": "audio/aiff",
    "aif": "audio/aiff",
}

# Placeholder content types an object store reports when the uploader did not
# set one. They carry no information about the data, so they must not shadow
# the extension-based guess in ``stat``.
_GENERIC_CONTENT_TYPES: frozenset[str] = frozenset(
    {"binary/octet-stream", "application/octet-stream"}
)

# Number of bytes requested per read when streaming an object body.
_CHUNK_SIZE = 65536


def _guess_content_type(key: str) -> str | None:
    """Return the MIME type implied by *key*'s file extension, or ``None``."""
    ext = Path(key).suffix.lower().lstrip(".")
    return _EXT_CONTENT_TYPE.get(ext)


def _not_found(key: str, exc: Exception) -> StorageError:
    """Build the ``StorageError`` for a missing *key*, chained to *exc*."""
    err = StorageError(f"Object not found: {key}")
    err.__cause__ = exc
    return err


def _is_404(exc: ClientError) -> bool:
    """Return True if *exc* reports that the requested object does not exist.

    Accepts both the bare HTTP status code ("404") and the S3 error code
    ("NoSuchKey"). Uses ``.get`` so a malformed error response yields False
    instead of raising ``KeyError``.
    """
    code = exc.response.get("Error", {}).get("Code")
    return code in ("404", "NoSuchKey")


def _translate_client_error(key: str, exc: ClientError) -> StorageError:
    """Map a botocore ``ClientError`` raised for *key* to a ``StorageError``.

    Missing objects become "Object not found: <key>"; anything else keeps the
    botocore message.
    """
    if _is_404(exc):
        return _not_found(key, exc)
    return StorageError(str(exc))


class S3FileStorage:
    """File storage backed by one bucket of an S3-compatible object store.

    Every operation opens its own short-lived client from a shared
    ``aioboto3.Session``. Botocore ``ClientError``s are translated into
    ``StorageError``.
    """

    def __init__(
        self,
        bucket: str,
        *,
        endpoint_url: str | None = None,
        region_name: str | None = None,
        access_key: str | None = None,
        secret_key: str | None = None,
    ) -> None:
        """Configure the storage.

        Args:
            bucket: Name of the bucket that holds all objects.
            endpoint_url: Endpoint passed to every S3 client (e.g. a MinIO URL).
                ``None`` leaves the choice to the SDK.
            region_name: Region passed to the ``aioboto3.Session``.
            access_key: Access key ID passed to the session. ``None`` leaves
                credential lookup to the SDK.
            secret_key: Secret access key passed to the session.
        """
        self._bucket = bucket
        self._endpoint_url = endpoint_url
        self._session: Any = aioboto3.Session(
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_key,
            region_name=region_name,
        )

    def _client(self) -> Any:
        """Return a new async S3 client context manager for the configured endpoint."""
        return self._session.client("s3", endpoint_url=self._endpoint_url)

    async def save_file(self, key: str, src_path: Path) -> int:
        """Upload the local file *src_path* as object *key*.

        When the key's extension maps to a known audio type, that type is
        stored as the object's ``ContentType``. Otherwise the store would
        apply a generic default.

        Returns:
            The number of bytes uploaded.

        Raises:
            StorageError: If the store rejects the upload.
        """
        async with await anyio.open_file(src_path, "rb") as f:
            content: bytes = await f.read()
        # NOTE(review): the whole file is buffered in memory before upload;
        # consider a streaming/multipart upload if files can be large.
        extra: dict[str, str] = {}
        content_type = _guess_content_type(key)
        if content_type is not None:
            extra["ContentType"] = content_type
        async with self._client() as s3:
            try:
                await s3.put_object(
                    Bucket=self._bucket, Key=key, Body=content, **extra
                )
            except ClientError as exc:
                raise StorageError(str(exc)) from exc
        return len(content)

    async def open_range(
        self, key: str, start: int, end: int | None
    ) -> tuple[AsyncIterator[bytes], int]:
        """Open a byte range of object *key* for streaming.

        Args:
            key: Object key.
            start: First byte offset of the range.
            end: Last byte offset of the range, inclusive as in an HTTP
                ``Range`` header. ``None`` reads to the end of the object.

        Returns:
            A pair ``(chunks, total_size)``. *chunks* lazily yields the
            requested bytes. *total_size* is the size of the whole object,
            not of the range.

        Raises:
            StorageError: If the object is missing or the HEAD request fails.
                Failures of the ranged GET are raised lazily, when *chunks*
                is first iterated.
        """
        async with self._client() as s3:
            try:
                head = await s3.head_object(Bucket=self._bucket, Key=key)
            except ClientError as exc:
                raise _translate_client_error(key, exc) from exc
        total_size: int = head["ContentLength"]
        range_header = (
            f"bytes={start}-{end}" if end is not None else f"bytes={start}-"
        )
        return self._stream_range(key, range_header), total_size

    async def _stream_range(
        self, key: str, range_header: str
    ) -> AsyncGenerator[bytes]:
        """Yield the body of a ranged GET on *key*, using a client of its own.

        The client stays open only while the generator is being consumed.
        """
        async with self._client() as s3:
            try:
                resp = await s3.get_object(
                    Bucket=self._bucket, Key=key, Range=range_header
                )
            except ClientError as exc:
                # The object may have been deleted between HEAD and GET.
                raise _translate_client_error(key, exc) from exc
            body = resp["Body"]
            while chunk := await body.read(_CHUNK_SIZE):
                yield chunk

    async def stat(self, key: str) -> ObjectStat:
        """Return the size and content type of object *key*.

        A stored content type is used unless it is missing or one of the
        generic ``*/octet-stream`` placeholders. In that case the type implied
        by the key's extension is preferred, falling back to the stored value.

        Raises:
            StorageError: If the object is missing or the request fails.
        """
        async with self._client() as s3:
            try:
                head = await s3.head_object(Bucket=self._bucket, Key=key)
            except ClientError as exc:
                raise _translate_client_error(key, exc) from exc
        stored: str | None = head.get("ContentType")
        if stored and stored.lower() not in _GENERIC_CONTENT_TYPES:
            content_type = stored
        else:
            content_type = _guess_content_type(key) or stored or None
        return ObjectStat(size=head["ContentLength"], content_type=content_type)

    async def exists(self, key: str) -> bool:
        """Return whether object *key* exists.

        Raises:
            StorageError: If the HEAD request fails for a reason other than
                the object being missing.
        """
        async with self._client() as s3:
            try:
                await s3.head_object(Bucket=self._bucket, Key=key)
            except ClientError as exc:
                if _is_404(exc):
                    return False
                raise StorageError(str(exc)) from exc
        return True

    async def delete(self, key: str) -> None:
        """Delete object *key*.

        Raises:
            StorageError: If the store reports an error.
        """
        async with self._client() as s3:
            try:
                await s3.delete_object(Bucket=self._bucket, Key=key)
            except ClientError as exc:
                raise StorageError(str(exc)) from exc

    async def disk_usage(self) -> None:
        """Return ``None``: object stores have no fixed-capacity volume to report."""
        return None

    def as_local_path(self, key: str) -> AbstractAsyncContextManager[Path]:
        """Return a context manager yielding a temporary local copy of *key*.

        The copy keeps the key's file suffix and is deleted on exit.
        """
        return self._as_local_path_cm(key)

    @asynccontextmanager
    async def _as_local_path_cm(self, key: str) -> AsyncGenerator[Path]:
        """Download *key* to a temp file, yield its path, then remove the file."""
        # delete=False: the file must outlive this ``with`` so it can be
        # reopened by path; the ``finally`` below removes it.
        with tempfile.NamedTemporaryFile(
            suffix=Path(key).suffix, delete=False
        ) as f:
            tmp_path = Path(f.name)
        try:
            await self._download_to(key, tmp_path)
            # The S3 client is already closed here, so no connection is held
            # while the caller works with the file.
            yield tmp_path
        finally:
            await anyio.Path(tmp_path).unlink(missing_ok=True)

    async def _download_to(self, key: str, dest: Path) -> None:
        """Stream object *key* into the local file *dest*, overwriting it.

        Raises:
            StorageError: If the object is missing or the GET request fails.
        """
        async with self._client() as s3:
            try:
                resp = await s3.get_object(Bucket=self._bucket, Key=key)
            except ClientError as exc:
                raise _translate_client_error(key, exc) from exc
            body = resp["Body"]
            async with await anyio.open_file(dest, "wb") as out:
                while chunk := await body.read(_CHUNK_SIZE):
                    await out.write(chunk)