From 81ea93c371f2a32e52befbfeda1791de04b611da Mon Sep 17 00:00:00 2001 From: Senko-san Date: Sun, 7 Jun 2026 15:34:06 +0300 Subject: [PATCH] feat: local storage logic & endpoints --- app/api/deps.py | 32 ++- app/api/errors.py | 11 ++ app/api/schemas/upload.py | 11 ++ app/api/v1/streaming.py | 36 +++- app/api/v1/upload.py | 21 +- app/application/streaming_service.py | 97 +++++++++ app/application/upload_service.py | 116 +++++++++++ app/core/config.py | 9 + app/core/hashing.py | 12 ++ app/domain/entities/__init__.py | 4 +- app/domain/entities/storage.py | 9 + app/domain/entities/track.py | 29 +++ app/domain/errors.py | 16 ++ app/domain/ports.py | 41 +++- .../db/repositories/__init__.py | 9 +- .../db/repositories/artist_repository.py | 32 +++ .../db/repositories/track_repository.py | 83 ++++++++ app/infrastructure/storage/__init__.py | 1 + app/infrastructure/storage/local.py | 86 ++++++++ app/infrastructure/storage/provider.py | 16 ++ tests/test_auth_api.py | 3 +- tests/test_storage_local.py | 104 ++++++++++ tests/test_upload_stream_api.py | 185 ++++++++++++++++++ 23 files changed, 945 insertions(+), 18 deletions(-) create mode 100644 app/api/schemas/upload.py create mode 100644 app/application/streaming_service.py create mode 100644 app/application/upload_service.py create mode 100644 app/core/hashing.py create mode 100644 app/domain/entities/storage.py create mode 100644 app/domain/entities/track.py create mode 100644 app/infrastructure/db/repositories/artist_repository.py create mode 100644 app/infrastructure/db/repositories/track_repository.py create mode 100644 app/infrastructure/storage/__init__.py create mode 100644 app/infrastructure/storage/local.py create mode 100644 app/infrastructure/storage/provider.py create mode 100644 tests/test_storage_local.py create mode 100644 tests/test_upload_stream_api.py diff --git a/app/api/deps.py b/app/api/deps.py index 06d7704..a5c6036 100644 --- a/app/api/deps.py +++ b/app/api/deps.py @@ -15,17 +15,22 @@ from 
def get_upload_service(session: SessionDep, storage: FileStorageDep) -> UploadService:
    """Build the :class:`UploadService` injected into upload endpoints.

    Both repositories share the injected ``session``; the scratch directory
    for in-flight uploads is read from ``Settings.upload_tmp_dir`` (``None``
    lets the service fall back to the platform default temp directory).
    """
    track_repo = SqlAlchemyTrackRepository(session)
    artist_repo = SqlAlchemyArtistRepository(session)
    scratch_dir = get_settings().upload_tmp_dir
    return UploadService(
        tracks=track_repo,
        artists=artist_repo,
        storage=storage,
        tmp_dir=scratch_dir,
    )
current user / authorization ---------------------------------------------- # auto_error=False: we raise domain AuthenticationError (mapped to 401) so the # error envelope stays consistent with the rest of the API. diff --git a/app/api/errors.py b/app/api/errors.py index 327fc3f..5114e22 100644 --- a/app/api/errors.py +++ b/app/api/errors.py @@ -12,6 +12,8 @@ from app.domain.errors import ( DomainError, NotFoundError, PermissionDeniedError, + RangeNotSatisfiableError, + StorageError, ValidationError, ) @@ -25,6 +27,7 @@ _STATUS_BY_ERROR: dict[type[DomainError], int] = { AuthenticationError: status.HTTP_401_UNAUTHORIZED, PermissionDeniedError: status.HTTP_403_FORBIDDEN, DependencyUnavailableError: status.HTTP_503_SERVICE_UNAVAILABLE, + StorageError: status.HTTP_500_INTERNAL_SERVER_ERROR, } @@ -33,6 +36,14 @@ def _error_body(code: str, message: str) -> dict[str, dict[str, str]]: def register_exception_handlers(app: FastAPI) -> None: + @app.exception_handler(RangeNotSatisfiableError) + async def _handle_range_error(_request: Request, exc: RangeNotSatisfiableError) -> JSONResponse: + return JSONResponse( + status_code=status.HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE, + content=_error_body(exc.code, exc.message), + headers={"Content-Range": f"bytes */{exc.total_size}"}, + ) + @app.exception_handler(DomainError) async def _handle_domain_error(_request: Request, exc: DomainError) -> JSONResponse: http_status = _STATUS_BY_ERROR.get(type(exc), status.HTTP_400_BAD_REQUEST) diff --git a/app/api/schemas/upload.py b/app/api/schemas/upload.py new file mode 100644 index 0000000..4a5df3a --- /dev/null +++ b/app/api/schemas/upload.py @@ -0,0 +1,11 @@ +"""Schemas for upload responses.""" + +import uuid + +from pydantic import BaseModel + + +class UploadResponse(BaseModel): + track_id: uuid.UUID + title: str + already_exists: bool diff --git a/app/api/v1/streaming.py b/app/api/v1/streaming.py index 7642730..3e45323 100644 --- a/app/api/v1/streaming.py +++ b/app/api/v1/streaming.py @@ 
@router.get("/{track_id}")
async def stream_track(
    track_id: uuid.UUID,
    service: StreamingServiceDep,
    range_header: Annotated[str | None, Header(alias="Range")] = None,
) -> StreamingResponse:
    """Stream a track's bytes, honouring an optional ``Range`` request header.

    Answers 206 with a ``Content-Range`` header when the service reports a
    partial result, and 200 with the full payload otherwise. Errors raised by
    the service propagate unchanged.
    """
    opened = await service.open_stream(track_id, range_header)

    response_headers = {
        "Accept-Ranges": "bytes",
        "Content-Length": str(opened.content_length),
    }
    if opened.is_partial:
        # Only partial responses advertise which slice of the object they carry.
        response_headers["Content-Range"] = (
            f"bytes {opened.start}-{opened.end}/{opened.total_size}"
        )

    return StreamingResponse(
        opened.stream,
        status_code=206 if opened.is_partial else 200,
        headers=response_headers,
        media_type=opened.content_type,
    )
+@router.post("", response_model=UploadResponse) +async def upload_file( + file: Annotated[UploadFile, File()], + current_user: CurrentUser, + service: UploadServiceDep, +) -> UploadResponse: + result = await service.handle_upload(upload=file, user=current_user) + return UploadResponse( + track_id=result.track_id, + title=result.title, + already_exists=result.already_exists, + ) diff --git a/app/application/streaming_service.py b/app/application/streaming_service.py new file mode 100644 index 0000000..34b4c97 --- /dev/null +++ b/app/application/streaming_service.py @@ -0,0 +1,97 @@ +"""StreamingService — resolves a track and opens a byte-range stream.""" + +import re +import uuid +from collections.abc import AsyncIterator +from dataclasses import dataclass + +from app.domain.errors import NotFoundError, RangeNotSatisfiableError +from app.domain.ports import FileStorage, TrackRepository + +_FORMAT_CONTENT_TYPE: dict[str, str] = { + "mp3": "audio/mpeg", + "flac": "audio/flac", + "m4a": "audio/mp4", + "aac": "audio/aac", + "ogg": "audio/ogg", + "opus": "audio/ogg", + "wav": "audio/wav", + "wma": "audio/x-ms-wma", + "aiff": "audio/aiff", + "aif": "audio/aiff", +} + +_RANGE_RE = re.compile(r"bytes=(\d+)-(\d*)") + + +@dataclass +class StreamResult: + stream: AsyncIterator[bytes] + total_size: int + content_length: int + content_type: str + start: int + end: int + is_partial: bool + + +def _parse_range(header: str | None, total_size: int) -> tuple[int, int | None, bool]: + """Return (start, end, is_partial). 
Raises RangeNotSatisfiableError on invalid range.""" + if header is None: + return 0, None, False + + m = _RANGE_RE.fullmatch(header.strip()) + if not m: + return 0, None, False # malformed → treat as absent per RFC 7233 + + start = int(m.group(1)) + end: int | None = int(m.group(2)) if m.group(2) else None + + if start >= total_size: + raise RangeNotSatisfiableError(total_size) + + if end is not None: + if end >= total_size: + end = total_size - 1 + if end < start: + raise RangeNotSatisfiableError(total_size) + + return start, end, True + + +class StreamingService: + def __init__(self, tracks: TrackRepository, storage: FileStorage) -> None: + self._tracks = tracks + self._storage = storage + + async def open_stream( + self, + track_id: uuid.UUID, + range_header: str | None, + ) -> StreamResult: + track = await self._tracks.get_by_id(track_id) + if track is None: + raise NotFoundError("Track not found.") + + stat = await self._storage.stat(track.file_path) + total_size = stat.size + content_type = stat.content_type or _FORMAT_CONTENT_TYPE.get( + track.file_format.lower(), "application/octet-stream" + ) + + start, end, is_partial = _parse_range(range_header, total_size) + + stream, _ = await self._storage.open_range(track.file_path, start, end) + + actual_end = end if end is not None else total_size - 1 + content_length = actual_end - start + 1 + + return StreamResult( + stream=stream, + total_size=total_size, + content_length=content_length, + content_type=content_type, + start=start, + end=actual_end, + is_partial=is_partial, + ) diff --git a/app/application/upload_service.py b/app/application/upload_service.py new file mode 100644 index 0000000..e4e3760 --- /dev/null +++ b/app/application/upload_service.py @@ -0,0 +1,116 @@ +"""UploadService — handles user file uploads.""" + +import contextlib +import hashlib +import os +import tempfile +import uuid +from dataclasses import dataclass +from pathlib import Path +from typing import Protocol + +import anyio + +from 
async def _stream_to_temp(upload: UploadFileProtocol, dest: Path) -> tuple[str, int]:
    """Spool ``upload`` into ``dest`` in 64 KiB reads, hashing along the way.

    Returns:
        ``(sha256_hex, bytes_written)`` for the data written to ``dest``.
    """
    digest = hashlib.sha256()
    written = 0
    async with await anyio.open_file(dest, "wb") as sink:
        # An empty read marks the end of the upload.
        while block := await upload.read(65536):
            digest.update(block)
            await sink.write(block)
            written += len(block)
    return digest.hexdigest(), written
source="upload", + source_id=sha256_hex, + metadata_status="pending", + added_by=user.id, + ) + except Exception: + with contextlib.suppress(Exception): + await self._storage.delete(key) + raise + + # TODO(1D): enqueue metadata enrichment task + + return UploadResult( + track_id=track.id, + title=track.title, + already_exists=False, + ) + finally: + await anyio.Path(tmp_path).unlink(missing_ok=True) diff --git a/app/core/config.py b/app/core/config.py index ac36fa6..d9874f4 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -49,6 +49,15 @@ class Settings(BaseSettings): media_path: Path = Path("/data/media") transcode_cache_path: Path = Path("/data/transcode-cache") max_parallel_downloads: int = 2 + storage_backend: Literal["local", "s3"] = "local" + upload_tmp_dir: Path | None = None + + # -- S3 storage (deferred; set storage_backend="s3" to use) ---------- + s3_endpoint_url: str | None = None + s3_bucket: str | None = None + s3_region: str | None = None + s3_access_key: SecretStr | None = None + s3_secret_key: SecretStr | None = None # -- external services (all optional; graceful degradation) ---------- ml_service_url: str | None = None diff --git a/app/core/hashing.py b/app/core/hashing.py new file mode 100644 index 0000000..33ea29e --- /dev/null +++ b/app/core/hashing.py @@ -0,0 +1,12 @@ +"""File hashing utilities.""" + +import hashlib +from pathlib import Path + + +def sha256_of_file(path: Path) -> str: + h = hashlib.sha256() + with open(path, "rb") as f: + for chunk in iter(lambda: f.read(65536), b""): + h.update(chunk) + return h.hexdigest() diff --git a/app/domain/entities/__init__.py b/app/domain/entities/__init__.py index 4b76ded..5b9ff76 100644 --- a/app/domain/entities/__init__.py +++ b/app/domain/entities/__init__.py @@ -1,5 +1,7 @@ """Domain entities and value objects — pure, framework-free.""" +from app.domain.entities.storage import ObjectStat +from app.domain.entities.track import Artist, Track from app.domain.entities.user import 
Credentials, User -__all__ = ["Credentials", "User"] +__all__ = ["Artist", "Credentials", "ObjectStat", "Track", "User"] diff --git a/app/domain/entities/storage.py b/app/domain/entities/storage.py new file mode 100644 index 0000000..8a00444 --- /dev/null +++ b/app/domain/entities/storage.py @@ -0,0 +1,9 @@ +"""Value objects for file storage.""" + +from dataclasses import dataclass + + +@dataclass(frozen=True, slots=True) +class ObjectStat: + size: int + content_type: str | None diff --git a/app/domain/entities/track.py b/app/domain/entities/track.py new file mode 100644 index 0000000..dc4debb --- /dev/null +++ b/app/domain/entities/track.py @@ -0,0 +1,29 @@ +"""Track and Artist domain entities.""" + +import datetime as dt +import uuid +from dataclasses import dataclass + + +@dataclass(frozen=True, slots=True) +class Artist: + id: uuid.UUID + name: str + created_at: dt.datetime + updated_at: dt.datetime + + +@dataclass(frozen=True, slots=True) +class Track: + id: uuid.UUID + title: str + artist_id: uuid.UUID + file_path: str + file_format: str + file_size: int + source: str + source_id: str + duration_seconds: int | None + metadata_status: str + created_at: dt.datetime + updated_at: dt.datetime diff --git a/app/domain/errors.py b/app/domain/errors.py index 62360d0..a400240 100644 --- a/app/domain/errors.py +++ b/app/domain/errors.py @@ -61,3 +61,19 @@ class DependencyUnavailableError(DomainError): """ code = "dependency_unavailable" + + +class StorageError(DomainError): + """File storage operation failed.""" + + code = "storage_error" + + +class RangeNotSatisfiableError(DomainError): + """Requested byte range cannot be satisfied.""" + + code = "range_not_satisfiable" + + def __init__(self, total_size: int) -> None: + super().__init__("Requested range is not satisfiable.") + self.total_size = total_size diff --git a/app/domain/ports.py b/app/domain/ports.py index 6050f67..3b90adf 100644 --- a/app/domain/ports.py +++ b/app/domain/ports.py @@ -7,9 +7,13 @@ are bound 
class FileStorage(Protocol):
    """Port for the blob store holding track files, addressed by string key.

    The notes on each method describe what the local-filesystem adapter does;
    other adapters are expected to match -- confirm when adding one.
    """

    async def save_file(self, key: str, src_path: Path) -> int:
        """Store the local file ``src_path`` under ``key``.

        Returns an ``int``; the local adapter returns the stored size in bytes.
        """
        ...

    async def open_range(
        self, key: str, start: int, end: int | None
    ) -> tuple[AsyncIterator[bytes], int]:
        """Open ``key`` for reading from byte ``start`` through ``end``.

        ``end`` is inclusive; ``None`` means "through the end of the object".
        Returns ``(chunk_iterator, total_object_size)``.
        """
        ...

    async def stat(self, key: str) -> ObjectStat:
        """Return size and (possibly unknown) content type of ``key``."""
        ...

    async def exists(self, key: str) -> bool:
        """Report whether an object is stored under ``key``."""
        ...

    async def delete(self, key: str) -> None:
        """Remove ``key``; the local adapter treats a missing object as a no-op."""
        ...

    def as_local_path(self, key: str) -> AbstractAsyncContextManager[Path]:
        """Return an async context manager yielding a local path for ``key``."""
        ...
class SqlAlchemyArtistRepository:
    """Artist persistence on top of a caller-owned ``AsyncSession``."""

    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    async def _find_by_name(self, name: str) -> ArtistModel | None:
        """Return the earliest-created artist named exactly ``name``, if any.

        Takes the first match instead of demanding uniqueness, so rows that
        were duplicated by an earlier race cannot make every lookup fail.
        """
        result = await self._session.execute(
            select(ArtistModel)
            .where(ArtistModel.name == name)
            .order_by(ArtistModel.created_at, ArtistModel.id)
            .limit(1)
        )
        return result.scalars().first()

    async def get_or_create(self, name: str) -> Artist:
        """Return the artist called ``name``, inserting it when absent.

        The insert runs inside a SAVEPOINT. If a concurrent transaction wins
        the race and the database rejects the insert with a uniqueness
        violation, only the savepoint is rolled back (the caller's outer
        transaction stays usable) and the winner's row is returned.

        Raises:
            sqlalchemy.exc.IntegrityError: The insert was rejected and no row
                with this name is visible afterwards (i.e. a different
                constraint failed).
        """
        # Local import: the module-level import block only brings in ``select``.
        from sqlalchemy.exc import IntegrityError

        row = await self._find_by_name(name)
        if row is not None:
            return _to_entity(row)

        candidate = ArtistModel(name=name)
        try:
            async with self._session.begin_nested():
                self._session.add(candidate)
                await self._session.flush()
        except IntegrityError:
            row = await self._find_by_name(name)
            if row is None:
                raise
            return _to_entity(row)

        # Pull server-generated columns (timestamps) onto the new row.
        await self._session.refresh(candidate)
        return _to_entity(candidate)
class LocalFileStorage:
    """File storage adapter keeping each object as a file under one root directory.

    Keys are relative paths below ``media_path``. Blocking filesystem calls are
    delegated to worker threads (``anyio.to_thread`` / ``anyio.Path``) so the
    ``async`` methods do not stall the event loop.
    """

    _CHUNK_SIZE = 65536  # bytes per read when streaming an object

    def __init__(self, media_path: Path) -> None:
        self._media_path = media_path

    def _resolve(self, key: str) -> Path:
        """Map ``key`` to an absolute path strictly inside the media root.

        The check is lexical (``os.path.abspath``); symlinks are not followed.

        Raises:
            StorageError: ``key`` is empty, absolute, or climbs out via ``..``.
        """
        root = Path(os.path.abspath(self._media_path))
        candidate = Path(os.path.abspath(root / key))
        if candidate == root or not candidate.is_relative_to(root):
            raise StorageError(f"Invalid storage key: {key}")
        return candidate

    @staticmethod
    async def _stat_or_raise(path: Path, key: str) -> os.stat_result:
        """``stat`` ``path`` off-thread, mapping "missing" to ``StorageError``."""
        try:
            return await anyio.Path(path).stat()
        except (FileNotFoundError, NotADirectoryError):
            raise StorageError(f"Object not found: {key}") from None

    async def save_file(self, key: str, src_path: Path) -> int:
        """Copy ``src_path`` into the store under ``key``; return the stored size.

        The data is first written to a sibling ``.part`` file and then moved
        into place with ``os.replace``, so readers never see a half-written
        object. The ``.part`` file is removed if the copy fails.

        Raises:
            StorageError: ``key`` is not a valid key.
            OSError: The copy or rename failed.
        """
        from anyio import to_thread

        dest = self._resolve(key)
        part = dest.with_suffix(dest.suffix + ".part")

        def _copy_into_place() -> int:
            dest.parent.mkdir(parents=True, exist_ok=True)
            try:
                shutil.copyfile(str(src_path), str(part))
                os.replace(str(part), str(dest))
            except OSError:
                part.unlink(missing_ok=True)
                raise
            return dest.stat().st_size

        return await to_thread.run_sync(_copy_into_place)

    async def open_range(
        self, key: str, start: int, end: int | None
    ) -> tuple[AsyncIterator[bytes], int]:
        """Open ``key`` for reading bytes ``start`` through ``end`` (inclusive).

        ``end=None`` reads through the end of the file. The file itself is
        opened lazily, when the returned iterator is first advanced.

        Returns:
            ``(chunk_iterator, total_size_of_object)``.

        Raises:
            StorageError: ``key`` is invalid or no such object exists.
        """
        path = self._resolve(key)
        total_size = (await self._stat_or_raise(path, key)).st_size
        wanted = (end - start + 1) if end is not None else (total_size - start)
        chunk_size = self._CHUNK_SIZE

        async def _chunks() -> AsyncIterator[bytes]:
            remaining = wanted
            async with await anyio.open_file(path, "rb") as f:
                await f.seek(start)
                while remaining > 0:
                    chunk: bytes = await f.read(min(chunk_size, remaining))
                    if not chunk:
                        break  # file is shorter than expected; stop cleanly
                    yield chunk
                    remaining -= len(chunk)

        return _chunks(), total_size

    async def stat(self, key: str) -> ObjectStat:
        """Return size and extension-derived content type for ``key``.

        ``content_type`` is ``None`` when the extension is not a known audio type.

        Raises:
            StorageError: ``key`` is invalid or no such object exists.
        """
        path = self._resolve(key)
        st = await self._stat_or_raise(path, key)
        ext = path.suffix.lower().lstrip(".")
        return ObjectStat(size=st.st_size, content_type=_EXT_CONTENT_TYPE.get(ext))

    async def exists(self, key: str) -> bool:
        """Report whether an object is stored under ``key``."""
        return await anyio.Path(self._resolve(key)).exists()

    async def delete(self, key: str) -> None:
        """Remove the object under ``key``; a missing object is a no-op."""
        await anyio.Path(self._resolve(key)).unlink(missing_ok=True)

    def as_local_path(self, key: str) -> AbstractAsyncContextManager[Path]:
        """Return an async context manager yielding the on-disk path of ``key``.

        The file already lives on local disk, so nothing is copied or cleaned
        up; existence is not checked.
        """
        return self._as_local_path_cm(key)

    @asynccontextmanager
    async def _as_local_path_cm(self, key: str) -> AsyncIterator[Path]:
        yield self._resolve(key)
= LocalFileStorage(settings.media_path) + return _storage diff --git a/tests/test_auth_api.py b/tests/test_auth_api.py index 54b4f91..41410c3 100644 --- a/tests/test_auth_api.py +++ b/tests/test_auth_api.py @@ -10,7 +10,7 @@ from collections.abc import AsyncIterator import pytest from app.core.security import Argon2PasswordHasher -from app.infrastructure.db import Base, get_engine, session_scope +from app.infrastructure.db import Base, dispose_engine, get_engine, session_scope from app.infrastructure.db.repositories import ( SqlAlchemyRefreshTokenRepository, SqlAlchemyUserRepository, @@ -73,6 +73,7 @@ async def api() -> AsyncIterator[AsyncClient]: async with get_engine().begin() as conn: await conn.run_sync(Base.metadata.drop_all) + await dispose_engine() async def _login(api: AsyncClient, username: str, password: str) -> tuple[str, str]: diff --git a/tests/test_storage_local.py b/tests/test_storage_local.py new file mode 100644 index 0000000..9628f63 --- /dev/null +++ b/tests/test_storage_local.py @@ -0,0 +1,104 @@ +"""Unit tests for LocalFileStorage.""" + +import pytest +from app.infrastructure.storage.local import LocalFileStorage + +pytestmark = pytest.mark.asyncio + + +async def test_save_and_stat(tmp_path): + storage = LocalFileStorage(tmp_path) + src = tmp_path / "src.mp3" + src.write_bytes(b"test audio data") + + size = await storage.save_file("tracks/te/test.mp3", src) + assert size == 15 + + stat = await storage.stat("tracks/te/test.mp3") + assert stat.size == 15 + assert stat.content_type == "audio/mpeg" + + +async def test_save_creates_parent_dirs(tmp_path): + storage = LocalFileStorage(tmp_path) + src = tmp_path / "src.flac" + src.write_bytes(b"x") + + await storage.save_file("tracks/ab/cdef.flac", src) + assert (tmp_path / "tracks" / "ab" / "cdef.flac").exists() + + +async def test_open_range_full(tmp_path): + storage = LocalFileStorage(tmp_path) + data = b"hello world" * 100 + src = tmp_path / "src.flac" + src.write_bytes(data) + await 
storage.save_file("tracks/he/hello.flac", src)
+
+    stream, total = await storage.open_range("tracks/he/hello.flac", 0, None)
+    result = b"".join([chunk async for chunk in stream])
+    assert result == data
+    assert total == len(data)
+
+
+async def test_open_range_partial(tmp_path):
+    storage = LocalFileStorage(tmp_path)
+    data = b"0123456789"
+    src = tmp_path / "src.mp3"
+    src.write_bytes(data)
+    await storage.save_file("tracks/sr/src.mp3", src)
+
+    stream, total = await storage.open_range("tracks/sr/src.mp3", 3, 7)
+    result = b"".join([chunk async for chunk in stream])
+    assert result == b"34567"
+    assert total == 10
+
+
+async def test_open_range_from_offset_to_end(tmp_path):
+    storage = LocalFileStorage(tmp_path)
+    data = b"abcdefghij"
+    src = tmp_path / "src.wav"
+    src.write_bytes(data)
+    await storage.save_file("tracks/sr/src.wav", src)
+
+    stream, total = await storage.open_range("tracks/sr/src.wav", 5, None)
+    result = b"".join([chunk async for chunk in stream])
+    assert result == b"fghij"
+    assert total == 10
+
+
+async def test_exists_and_delete(tmp_path):
+    storage = LocalFileStorage(tmp_path)
+    src = tmp_path / "src.ogg"
+    src.write_bytes(b"ogg data")
+    await storage.save_file("tracks/sr/src.ogg", src)
+
+    assert await storage.exists("tracks/sr/src.ogg") is True
+    await storage.delete("tracks/sr/src.ogg")
+    assert await storage.exists("tracks/sr/src.ogg") is False
+
+
+async def test_delete_missing_is_noop(tmp_path):
+    storage = LocalFileStorage(tmp_path)
+    await storage.delete("tracks/no/nope.mp3")
+
+
+async def test_as_local_path(tmp_path):
+    storage = LocalFileStorage(tmp_path)
+    src = tmp_path / "src.mp3"
+    src.write_bytes(b"local bytes")
+    await storage.save_file("tracks/lo/local.mp3", src)
+
+    async with storage.as_local_path("tracks/lo/local.mp3") as path:
+        assert path.read_bytes() == b"local bytes"
+
+
+async def test_stat_unknown_extension(tmp_path):
+    storage = LocalFileStorage(tmp_path)
+    src = tmp_path / "src.xyz"
+    src.write_bytes(b"mystery")
+    await storage.save_file("tracks/my/mystery.xyz", src)
+
+    stat = await storage.stat("tracks/my/mystery.xyz")
+    assert stat.size == 7
+    assert stat.content_type is None
diff --git a/tests/test_upload_stream_api.py b/tests/test_upload_stream_api.py
new file mode 100644
index 0000000..a5a5a3f
--- /dev/null
+++ b/tests/test_upload_stream_api.py
@@ -0,0 +1,204 @@
+"""Integration tests for upload and streaming endpoints.
+
+Requires a reachable Postgres; skips otherwise.
+"""
+
+import asyncio
+import os
+from collections.abc import AsyncIterator
+from pathlib import Path
+
+import pytest
+from app.core.config import get_settings
+from app.infrastructure.db import Base, dispose_engine, get_engine, session_scope
+from app.infrastructure.db.repositories import (
+    SqlAlchemyRefreshTokenRepository,
+    SqlAlchemyUserRepository,
+)
+from asgi_lifespan import LifespanManager
+from httpx import ASGITransport, AsyncClient, Response
+
+pytestmark = pytest.mark.asyncio
+
+# Memoised probe result: None until the first probe, then True/False.
+_db_reachable_cache: bool | None = None
+
+
+async def _db_reachable() -> bool:
+    """Return whether the database answers ``SELECT 1`` within three seconds.
+
+    The first outcome is stored in ``_db_reachable_cache`` and reused by every
+    later call, so the probe runs at most once per process.
+    """
+    global _db_reachable_cache
+    if _db_reachable_cache is not None:
+        return _db_reachable_cache
+
+    from sqlalchemy import text
+
+    try:
+        async with asyncio.timeout(3):
+            async with get_engine().connect() as conn:
+                await conn.execute(text("SELECT 1"))
+    except Exception:
+        # Any failure (refused connection, timeout, ...) counts as unreachable.
+        _db_reachable_cache = False
+    else:
+        _db_reachable_cache = True
+    return _db_reachable_cache
+
+
+@pytest.fixture
+async def api(tmp_path: Path) -> AsyncIterator[AsyncClient]:
+    """Yield an HTTP client for a fresh app, schema and media directory.
+
+    Skips the test when the database is unreachable. Creates the user
+    ``testuser`` / ``testpass1`` that ``_login`` authenticates as.
+    """
+    if not await _db_reachable():
+        pytest.skip("Postgres not reachable — integration test skipped.")
+
+    # Remember any pre-existing value so teardown restores it instead of
+    # unconditionally deleting the variable.
+    previous_media_path = os.environ.get("MEDIA_PATH")
+    os.environ["MEDIA_PATH"] = str(tmp_path)
+    get_settings.cache_clear()
+
+    # Also reset the file storage singleton so it picks up the new media_path
+    import app.infrastructure.storage.provider as _storage_provider
+
+    _storage_provider._storage = None
+
+    try:
+        async with get_engine().begin() as conn:
+            await conn.run_sync(Base.metadata.drop_all)
+            await conn.run_sync(Base.metadata.create_all)
+
+        from app.application.user_service import UserService
+        from app.core.security import Argon2PasswordHasher
+
+        async with session_scope() as session:
+            await UserService(
+                users=SqlAlchemyUserRepository(session),
+                refresh_tokens=SqlAlchemyRefreshTokenRepository(session),
+                hasher=Argon2PasswordHasher(),
+            ).create_user(username="testuser", password="testpass1", is_superuser=False)
+
+        from app.main import create_app
+
+        app = create_app()
+        async with LifespanManager(app):
+            transport = ASGITransport(app=app)
+            async with AsyncClient(transport=transport, base_url="http://test") as client:
+                yield client
+
+        async with get_engine().begin() as conn:
+            await conn.run_sync(Base.metadata.drop_all)
+    finally:
+        try:
+            # Runs even when setup or teardown above raised, so the engine
+            # is never left undisposed for the next test.
+            await dispose_engine()
+        finally:
+            _storage_provider._storage = None
+            if previous_media_path is None:
+                os.environ.pop("MEDIA_PATH", None)
+            else:
+                os.environ["MEDIA_PATH"] = previous_media_path
+            get_settings.cache_clear()
+
+
+async def _login(api: AsyncClient) -> str:
+    """Log in as the fixture user and return the access token."""
+    resp = await api.post(
+        "/api/v1/auth/login", json={"username": "testuser", "password": "testpass1"}
+    )
+    assert resp.status_code == 200, resp.text
+    return str(resp.json()["access_token"])
+
+
+async def _upload(api: AsyncClient, token: str, filename: str, content: bytes) -> Response:
+    """POST *content* to the upload endpoint as ``audio/mpeg`` and return the response."""
+    return await api.post(
+        "/api/v1/upload",
+        files={"file": (filename, content, "audio/mpeg")},
+        headers={"Authorization": f"Bearer {token}"},
+    )
+
+
+async def test_upload_creates_track(api: AsyncClient) -> None:
+    """A first upload succeeds and reports a new track titled after the file."""
+    token = await _login(api)
+
+    resp = await _upload(api, token, "song.mp3", b"fake mp3 bytes" * 100)
+
+    assert resp.status_code == 200, resp.text
+    body = resp.json()
+    assert body["already_exists"] is False
+    assert body["title"] == "song"
+    assert "track_id" in body
+
+
+async def test_upload_dedup(api: AsyncClient) -> None:
+    """Uploading identical bytes under another name returns the existing track."""
+    token = await _login(api)
+    audio = b"same content" * 50
+
+    first = await _upload(api, token, "a.mp3", audio)
+    assert first.status_code == 200, first.text
+    assert first.json()["already_exists"] is False
+
+    second = await _upload(api, token, "b.mp3", audio)
+    assert second.status_code == 200, second.text
+    assert second.json()["already_exists"] is True
+    assert second.json()["track_id"] == first.json()["track_id"]
+
+
+async def test_stream_full(api: AsyncClient) -> None:
+    """Streaming without a Range header returns the whole file with status 200."""
+    token = await _login(api)
+    audio = b"audio data for streaming" * 10
+
+    up = await _upload(api, token, "track.mp3", audio)
+    assert up.status_code == 200, up.text
+    track_id = up.json()["track_id"]
+
+    resp = await api.get(f"/api/v1/stream/{track_id}")
+    assert resp.status_code == 200
+    assert resp.content == audio
+    assert resp.headers["content-type"].startswith("audio/mpeg")
+    assert "accept-ranges" in resp.headers
+
+
+async def test_stream_range(api: AsyncClient) -> None:
+    """A ``bytes=0-9`` request yields 206 with the first ten bytes."""
+    token = await _login(api)
+    audio = b"0123456789" * 10
+
+    up = await _upload(api, token, "range.mp3", audio)
+    assert up.status_code == 200, up.text
+    track_id = up.json()["track_id"]
+
+    resp = await api.get(
+        f"/api/v1/stream/{track_id}",
+        headers={"Range": "bytes=0-9"},
+    )
+    assert resp.status_code == 206
+    assert resp.content == b"0123456789"
+    assert resp.headers["content-range"] == f"bytes 0-9/{len(audio)}"
+    assert resp.headers["content-length"] == "10"
+
+
+async def test_stream_not_found(api: AsyncClient) -> None:
+    """An unknown track id streams as 404."""
+    resp = await api.get("/api/v1/stream/00000000-0000-0000-0000-000000000000")
+    assert resp.status_code == 404
+
+
+async def test_upload_requires_auth(api: AsyncClient) -> None:
+    """Uploading without an Authorization header is rejected with 401."""
+    resp = await api.post(
+        "/api/v1/upload",
+        files={"file": ("x.mp3", b"data", "audio/mpeg")},
+    )
+    assert resp.status_code == 401