feat: local storage logic & endpoints
This commit is contained in:
+31
-1
@@ -15,17 +15,22 @@ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.application.auth_service import AuthService
|
||||
from app.application.streaming_service import StreamingService
|
||||
from app.application.upload_service import UploadService
|
||||
from app.application.user_service import UserService
|
||||
from app.core.config import get_settings
|
||||
from app.core.security import Argon2PasswordHasher, JwtTokenService
|
||||
from app.domain.entities import User
|
||||
from app.domain.errors import AuthenticationError, PermissionDeniedError
|
||||
from app.domain.ports import PasswordHasher, TokenService
|
||||
from app.domain.ports import FileStorage, PasswordHasher, TokenService
|
||||
from app.infrastructure.db import get_sessionmaker
|
||||
from app.infrastructure.db.repositories import (
|
||||
SqlAlchemyArtistRepository,
|
||||
SqlAlchemyRefreshTokenRepository,
|
||||
SqlAlchemyTrackRepository,
|
||||
SqlAlchemyUserRepository,
|
||||
)
|
||||
from app.infrastructure.storage.provider import get_file_storage
|
||||
|
||||
|
||||
async def get_session() -> AsyncIterator[AsyncSession]:
|
||||
@@ -77,6 +82,31 @@ AuthServiceDep = Annotated[AuthService, Depends(get_auth_service)]
|
||||
UserServiceDep = Annotated[UserService, Depends(get_user_service)]
|
||||
|
||||
|
||||
# -- file storage (process-cached) ---------------------------------------------
|
||||
FileStorageDep = Annotated[FileStorage, Depends(get_file_storage)]
|
||||
|
||||
|
||||
def get_upload_service(session: SessionDep, storage: FileStorageDep) -> UploadService:
|
||||
settings = get_settings()
|
||||
return UploadService(
|
||||
tracks=SqlAlchemyTrackRepository(session),
|
||||
artists=SqlAlchemyArtistRepository(session),
|
||||
storage=storage,
|
||||
tmp_dir=settings.upload_tmp_dir,
|
||||
)
|
||||
|
||||
|
||||
def get_streaming_service(session: SessionDep, storage: FileStorageDep) -> StreamingService:
|
||||
return StreamingService(
|
||||
tracks=SqlAlchemyTrackRepository(session),
|
||||
storage=storage,
|
||||
)
|
||||
|
||||
|
||||
UploadServiceDep = Annotated[UploadService, Depends(get_upload_service)]
|
||||
StreamingServiceDep = Annotated[StreamingService, Depends(get_streaming_service)]
|
||||
|
||||
|
||||
# -- current user / authorization ----------------------------------------------
|
||||
# auto_error=False: we raise domain AuthenticationError (mapped to 401) so the
|
||||
# error envelope stays consistent with the rest of the API.
|
||||
|
||||
@@ -12,6 +12,8 @@ from app.domain.errors import (
|
||||
DomainError,
|
||||
NotFoundError,
|
||||
PermissionDeniedError,
|
||||
RangeNotSatisfiableError,
|
||||
StorageError,
|
||||
ValidationError,
|
||||
)
|
||||
|
||||
@@ -25,6 +27,7 @@ _STATUS_BY_ERROR: dict[type[DomainError], int] = {
|
||||
AuthenticationError: status.HTTP_401_UNAUTHORIZED,
|
||||
PermissionDeniedError: status.HTTP_403_FORBIDDEN,
|
||||
DependencyUnavailableError: status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
StorageError: status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
}
|
||||
|
||||
|
||||
@@ -33,6 +36,14 @@ def _error_body(code: str, message: str) -> dict[str, dict[str, str]]:
|
||||
|
||||
|
||||
def register_exception_handlers(app: FastAPI) -> None:
|
||||
@app.exception_handler(RangeNotSatisfiableError)
|
||||
async def _handle_range_error(_request: Request, exc: RangeNotSatisfiableError) -> JSONResponse:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE,
|
||||
content=_error_body(exc.code, exc.message),
|
||||
headers={"Content-Range": f"bytes */{exc.total_size}"},
|
||||
)
|
||||
|
||||
@app.exception_handler(DomainError)
|
||||
async def _handle_domain_error(_request: Request, exc: DomainError) -> JSONResponse:
|
||||
http_status = _STATUS_BY_ERROR.get(type(exc), status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
"""Schemas for upload responses."""
|
||||
|
||||
import uuid
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class UploadResponse(BaseModel):
|
||||
track_id: uuid.UUID
|
||||
title: str
|
||||
already_exists: bool
|
||||
+27
-9
@@ -1,20 +1,38 @@
|
||||
"""Audio streaming endpoints: direct stream and HLS."""
|
||||
"""Audio streaming endpoint — direct stream with Range support."""
|
||||
|
||||
import uuid
|
||||
from typing import Any
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import APIRouter, Header
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app.api.deps import StreamingServiceDep
|
||||
|
||||
router = APIRouter(prefix="/stream", tags=["streaming"])
|
||||
|
||||
|
||||
@router.get("/{track_id}")
|
||||
async def stream_track(track_id: uuid.UUID) -> Any: ...
|
||||
async def stream_track(
|
||||
track_id: uuid.UUID,
|
||||
service: StreamingServiceDep,
|
||||
range_header: Annotated[str | None, Header(alias="Range")] = None,
|
||||
) -> StreamingResponse:
|
||||
result = await service.open_stream(track_id, range_header)
|
||||
|
||||
headers = {
|
||||
"Accept-Ranges": "bytes",
|
||||
"Content-Length": str(result.content_length),
|
||||
}
|
||||
|
||||
@router.get("/{track_id}/hls/playlist.m3u8")
|
||||
async def hls_playlist(track_id: uuid.UUID) -> Any: ...
|
||||
if result.is_partial:
|
||||
headers["Content-Range"] = f"bytes {result.start}-{result.end}/{result.total_size}"
|
||||
status_code = 206
|
||||
else:
|
||||
status_code = 200
|
||||
|
||||
|
||||
@router.get("/{track_id}/hls/{segment}")
|
||||
async def hls_segment(track_id: uuid.UUID, segment: str) -> Any: ...
|
||||
return StreamingResponse(
|
||||
result.stream,
|
||||
status_code=status_code,
|
||||
headers=headers,
|
||||
media_type=result.content_type,
|
||||
)
|
||||
|
||||
+17
-4
@@ -1,11 +1,24 @@
|
||||
"""Local file upload endpoint."""
|
||||
|
||||
from typing import Any
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import APIRouter, File, UploadFile
|
||||
|
||||
from app.api.deps import CurrentUser, UploadServiceDep
|
||||
from app.api.schemas.upload import UploadResponse
|
||||
|
||||
router = APIRouter(prefix="/upload", tags=["upload"])
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def upload_file() -> Any: ...
|
||||
@router.post("", response_model=UploadResponse)
|
||||
async def upload_file(
|
||||
file: Annotated[UploadFile, File()],
|
||||
current_user: CurrentUser,
|
||||
service: UploadServiceDep,
|
||||
) -> UploadResponse:
|
||||
result = await service.handle_upload(upload=file, user=current_user)
|
||||
return UploadResponse(
|
||||
track_id=result.track_id,
|
||||
title=result.title,
|
||||
already_exists=result.already_exists,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
"""StreamingService — resolves a track and opens a byte-range stream."""
|
||||
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
from dataclasses import dataclass
|
||||
|
||||
from app.domain.errors import NotFoundError, RangeNotSatisfiableError
|
||||
from app.domain.ports import FileStorage, TrackRepository
|
||||
|
||||
_FORMAT_CONTENT_TYPE: dict[str, str] = {
|
||||
"mp3": "audio/mpeg",
|
||||
"flac": "audio/flac",
|
||||
"m4a": "audio/mp4",
|
||||
"aac": "audio/aac",
|
||||
"ogg": "audio/ogg",
|
||||
"opus": "audio/ogg",
|
||||
"wav": "audio/wav",
|
||||
"wma": "audio/x-ms-wma",
|
||||
"aiff": "audio/aiff",
|
||||
"aif": "audio/aiff",
|
||||
}
|
||||
|
||||
_RANGE_RE = re.compile(r"bytes=(\d+)-(\d*)")
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamResult:
|
||||
stream: AsyncIterator[bytes]
|
||||
total_size: int
|
||||
content_length: int
|
||||
content_type: str
|
||||
start: int
|
||||
end: int
|
||||
is_partial: bool
|
||||
|
||||
|
||||
def _parse_range(header: str | None, total_size: int) -> tuple[int, int | None, bool]:
|
||||
"""Return (start, end, is_partial). Raises RangeNotSatisfiableError on invalid range."""
|
||||
if header is None:
|
||||
return 0, None, False
|
||||
|
||||
m = _RANGE_RE.fullmatch(header.strip())
|
||||
if not m:
|
||||
return 0, None, False # malformed → treat as absent per RFC 7233
|
||||
|
||||
start = int(m.group(1))
|
||||
end: int | None = int(m.group(2)) if m.group(2) else None
|
||||
|
||||
if start >= total_size:
|
||||
raise RangeNotSatisfiableError(total_size)
|
||||
|
||||
if end is not None:
|
||||
if end >= total_size:
|
||||
end = total_size - 1
|
||||
if end < start:
|
||||
raise RangeNotSatisfiableError(total_size)
|
||||
|
||||
return start, end, True
|
||||
|
||||
|
||||
class StreamingService:
|
||||
def __init__(self, tracks: TrackRepository, storage: FileStorage) -> None:
|
||||
self._tracks = tracks
|
||||
self._storage = storage
|
||||
|
||||
async def open_stream(
|
||||
self,
|
||||
track_id: uuid.UUID,
|
||||
range_header: str | None,
|
||||
) -> StreamResult:
|
||||
track = await self._tracks.get_by_id(track_id)
|
||||
if track is None:
|
||||
raise NotFoundError("Track not found.")
|
||||
|
||||
stat = await self._storage.stat(track.file_path)
|
||||
total_size = stat.size
|
||||
content_type = stat.content_type or _FORMAT_CONTENT_TYPE.get(
|
||||
track.file_format.lower(), "application/octet-stream"
|
||||
)
|
||||
|
||||
start, end, is_partial = _parse_range(range_header, total_size)
|
||||
|
||||
stream, _ = await self._storage.open_range(track.file_path, start, end)
|
||||
|
||||
actual_end = end if end is not None else total_size - 1
|
||||
content_length = actual_end - start + 1
|
||||
|
||||
return StreamResult(
|
||||
stream=stream,
|
||||
total_size=total_size,
|
||||
content_length=content_length,
|
||||
content_type=content_type,
|
||||
start=start,
|
||||
end=actual_end,
|
||||
is_partial=is_partial,
|
||||
)
|
||||
@@ -0,0 +1,116 @@
|
||||
"""UploadService — handles user file uploads."""
|
||||
|
||||
import contextlib
|
||||
import hashlib
|
||||
import os
|
||||
import tempfile
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Protocol
|
||||
|
||||
import anyio
|
||||
|
||||
from app.domain.entities.user import User
|
||||
from app.domain.ports import ArtistRepository, FileStorage, TrackRepository
|
||||
|
||||
|
||||
class UploadFileProtocol(Protocol):
|
||||
filename: str | None
|
||||
|
||||
async def read(self, size: int = -1) -> bytes: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class UploadResult:
|
||||
track_id: uuid.UUID
|
||||
title: str
|
||||
already_exists: bool
|
||||
|
||||
|
||||
async def _stream_to_temp(upload: UploadFileProtocol, dest: Path) -> tuple[str, int]:
|
||||
h = hashlib.sha256()
|
||||
size = 0
|
||||
async with await anyio.open_file(dest, "wb") as out:
|
||||
while True:
|
||||
chunk = await upload.read(65536)
|
||||
if not chunk:
|
||||
break
|
||||
h.update(chunk)
|
||||
await out.write(chunk)
|
||||
size += len(chunk)
|
||||
return h.hexdigest(), size
|
||||
|
||||
|
||||
class UploadService:
|
||||
def __init__(
|
||||
self,
|
||||
tracks: TrackRepository,
|
||||
artists: ArtistRepository,
|
||||
storage: FileStorage,
|
||||
tmp_dir: Path | None = None,
|
||||
) -> None:
|
||||
self._tracks = tracks
|
||||
self._artists = artists
|
||||
self._storage = storage
|
||||
self._tmp_dir = tmp_dir
|
||||
|
||||
async def handle_upload(
|
||||
self,
|
||||
*,
|
||||
upload: UploadFileProtocol,
|
||||
user: User,
|
||||
) -> UploadResult:
|
||||
filename = upload.filename or "unknown"
|
||||
ext = Path(filename).suffix.lower().lstrip(".") or "bin"
|
||||
title = Path(filename).stem or "Unknown"
|
||||
|
||||
fd, tmp_str = tempfile.mkstemp(
|
||||
suffix=f".{ext}",
|
||||
dir=str(self._tmp_dir) if self._tmp_dir else None,
|
||||
)
|
||||
tmp_path = Path(tmp_str)
|
||||
try:
|
||||
os.close(fd)
|
||||
sha256_hex, file_size = await _stream_to_temp(upload, tmp_path)
|
||||
|
||||
existing = await self._tracks.get_by_source("upload", sha256_hex)
|
||||
if existing is not None:
|
||||
return UploadResult(
|
||||
track_id=existing.id,
|
||||
title=existing.title,
|
||||
already_exists=True,
|
||||
)
|
||||
|
||||
track_id = uuid.uuid4()
|
||||
key = f"tracks/{str(track_id)[:2]}/{track_id}.{ext}"
|
||||
|
||||
await self._storage.save_file(key, tmp_path)
|
||||
try:
|
||||
artist = await self._artists.get_or_create("Unknown Artist")
|
||||
track = await self._tracks.add(
|
||||
id=track_id,
|
||||
title=title,
|
||||
artist_id=artist.id,
|
||||
file_path=key,
|
||||
file_format=ext,
|
||||
file_size=file_size,
|
||||
source="upload",
|
||||
source_id=sha256_hex,
|
||||
metadata_status="pending",
|
||||
added_by=user.id,
|
||||
)
|
||||
except Exception:
|
||||
with contextlib.suppress(Exception):
|
||||
await self._storage.delete(key)
|
||||
raise
|
||||
|
||||
# TODO(1D): enqueue metadata enrichment task
|
||||
|
||||
return UploadResult(
|
||||
track_id=track.id,
|
||||
title=track.title,
|
||||
already_exists=False,
|
||||
)
|
||||
finally:
|
||||
await anyio.Path(tmp_path).unlink(missing_ok=True)
|
||||
@@ -49,6 +49,15 @@ class Settings(BaseSettings):
|
||||
media_path: Path = Path("/data/media")
|
||||
transcode_cache_path: Path = Path("/data/transcode-cache")
|
||||
max_parallel_downloads: int = 2
|
||||
storage_backend: Literal["local", "s3"] = "local"
|
||||
upload_tmp_dir: Path | None = None
|
||||
|
||||
# -- S3 storage (deferred; set storage_backend="s3" to use) ----------
|
||||
s3_endpoint_url: str | None = None
|
||||
s3_bucket: str | None = None
|
||||
s3_region: str | None = None
|
||||
s3_access_key: SecretStr | None = None
|
||||
s3_secret_key: SecretStr | None = None
|
||||
|
||||
# -- external services (all optional; graceful degradation) ----------
|
||||
ml_service_url: str | None = None
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
"""File hashing utilities."""
|
||||
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def sha256_of_file(path: Path) -> str:
|
||||
h = hashlib.sha256()
|
||||
with open(path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(65536), b""):
|
||||
h.update(chunk)
|
||||
return h.hexdigest()
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Domain entities and value objects — pure, framework-free."""
|
||||
|
||||
from app.domain.entities.storage import ObjectStat
|
||||
from app.domain.entities.track import Artist, Track
|
||||
from app.domain.entities.user import Credentials, User
|
||||
|
||||
__all__ = ["Credentials", "User"]
|
||||
__all__ = ["Artist", "Credentials", "ObjectStat", "Track", "User"]
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
"""Value objects for file storage."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ObjectStat:
|
||||
size: int
|
||||
content_type: str | None
|
||||
@@ -0,0 +1,29 @@
|
||||
"""Track and Artist domain entities."""
|
||||
|
||||
import datetime as dt
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class Artist:
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
created_at: dt.datetime
|
||||
updated_at: dt.datetime
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class Track:
|
||||
id: uuid.UUID
|
||||
title: str
|
||||
artist_id: uuid.UUID
|
||||
file_path: str
|
||||
file_format: str
|
||||
file_size: int
|
||||
source: str
|
||||
source_id: str
|
||||
duration_seconds: int | None
|
||||
metadata_status: str
|
||||
created_at: dt.datetime
|
||||
updated_at: dt.datetime
|
||||
@@ -61,3 +61,19 @@ class DependencyUnavailableError(DomainError):
|
||||
"""
|
||||
|
||||
code = "dependency_unavailable"
|
||||
|
||||
|
||||
class StorageError(DomainError):
|
||||
"""File storage operation failed."""
|
||||
|
||||
code = "storage_error"
|
||||
|
||||
|
||||
class RangeNotSatisfiableError(DomainError):
|
||||
"""Requested byte range cannot be satisfied."""
|
||||
|
||||
code = "range_not_satisfiable"
|
||||
|
||||
def __init__(self, total_size: int) -> None:
|
||||
super().__init__("Requested range is not satisfiable.")
|
||||
self.total_size = total_size
|
||||
|
||||
+40
-1
@@ -7,9 +7,13 @@ are bound to these ports at the composition root (``app.api.deps``).
|
||||
|
||||
import datetime as dt
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import AbstractAsyncContextManager
|
||||
from pathlib import Path
|
||||
from typing import Protocol
|
||||
|
||||
from app.domain.entities import Credentials, User
|
||||
from app.domain.entities import Credentials, ObjectStat, User
|
||||
from app.domain.entities.track import Artist, Track
|
||||
from app.domain.tokens import IssuedToken, TokenClaims, TokenType
|
||||
|
||||
|
||||
@@ -56,3 +60,38 @@ class TokenService(Protocol):
|
||||
"""Verify signature + expiry and return claims. Raises
|
||||
:class:`~app.domain.errors.AuthenticationError` on any failure."""
|
||||
...
|
||||
|
||||
|
||||
class FileStorage(Protocol):
|
||||
async def save_file(self, key: str, src_path: Path) -> int: ...
|
||||
async def open_range(
|
||||
self, key: str, start: int, end: int | None
|
||||
) -> tuple[AsyncIterator[bytes], int]: ...
|
||||
async def stat(self, key: str) -> ObjectStat: ...
|
||||
async def exists(self, key: str) -> bool: ...
|
||||
async def delete(self, key: str) -> None: ...
|
||||
def as_local_path(self, key: str) -> AbstractAsyncContextManager[Path]: ...
|
||||
|
||||
|
||||
class ArtistRepository(Protocol):
|
||||
async def get_or_create(self, name: str) -> Artist: ...
|
||||
|
||||
|
||||
class TrackRepository(Protocol):
|
||||
async def get_by_id(self, track_id: uuid.UUID) -> Track | None: ...
|
||||
async def get_by_source(self, source: str, source_id: str) -> Track | None: ...
|
||||
async def add(
|
||||
self,
|
||||
*,
|
||||
id: uuid.UUID,
|
||||
title: str,
|
||||
artist_id: uuid.UUID,
|
||||
file_path: str,
|
||||
file_format: str,
|
||||
file_size: int,
|
||||
source: str,
|
||||
source_id: str,
|
||||
metadata_status: str,
|
||||
added_by: uuid.UUID | None,
|
||||
) -> Track: ...
|
||||
async def delete(self, track_id: uuid.UUID) -> None: ...
|
||||
|
||||
@@ -1,8 +1,15 @@
|
||||
"""SQLAlchemy repository adapters implementing the domain ports."""
|
||||
|
||||
from app.infrastructure.db.repositories.artist_repository import SqlAlchemyArtistRepository
|
||||
from app.infrastructure.db.repositories.refresh_token_repository import (
|
||||
SqlAlchemyRefreshTokenRepository,
|
||||
)
|
||||
from app.infrastructure.db.repositories.track_repository import SqlAlchemyTrackRepository
|
||||
from app.infrastructure.db.repositories.user_repository import SqlAlchemyUserRepository
|
||||
|
||||
__all__ = ["SqlAlchemyRefreshTokenRepository", "SqlAlchemyUserRepository"]
|
||||
__all__ = [
|
||||
"SqlAlchemyArtistRepository",
|
||||
"SqlAlchemyRefreshTokenRepository",
|
||||
"SqlAlchemyTrackRepository",
|
||||
"SqlAlchemyUserRepository",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
"""Artist repository — adapter over ``AsyncSession``."""
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.domain.entities.track import Artist
|
||||
from app.infrastructure.db.models.artist import ArtistModel
|
||||
|
||||
|
||||
def _to_entity(row: ArtistModel) -> Artist:
|
||||
return Artist(
|
||||
id=row.id,
|
||||
name=row.name,
|
||||
created_at=row.created_at,
|
||||
updated_at=row.updated_at,
|
||||
)
|
||||
|
||||
|
||||
class SqlAlchemyArtistRepository:
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
|
||||
async def get_or_create(self, name: str) -> Artist:
|
||||
row = (
|
||||
await self._session.execute(select(ArtistModel).where(ArtistModel.name == name))
|
||||
).scalar_one_or_none()
|
||||
if row is None:
|
||||
row = ArtistModel(name=name)
|
||||
self._session.add(row)
|
||||
await self._session.flush()
|
||||
await self._session.refresh(row)
|
||||
return _to_entity(row)
|
||||
@@ -0,0 +1,83 @@
|
||||
"""Track repository — adapter over ``AsyncSession``."""
|
||||
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.domain.entities.track import Track
|
||||
from app.infrastructure.db.models.track import TrackModel
|
||||
|
||||
|
||||
def _to_entity(row: TrackModel) -> Track:
|
||||
return Track(
|
||||
id=row.id,
|
||||
title=row.title,
|
||||
artist_id=row.artist_id,
|
||||
file_path=row.file_path,
|
||||
file_format=row.file_format,
|
||||
file_size=row.file_size,
|
||||
source=row.source,
|
||||
source_id=row.source_id,
|
||||
duration_seconds=row.duration_seconds,
|
||||
metadata_status=row.metadata_status,
|
||||
created_at=row.created_at,
|
||||
updated_at=row.updated_at,
|
||||
)
|
||||
|
||||
|
||||
class SqlAlchemyTrackRepository:
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
|
||||
async def get_by_id(self, track_id: uuid.UUID) -> Track | None:
|
||||
row = await self._session.get(TrackModel, track_id)
|
||||
return _to_entity(row) if row is not None else None
|
||||
|
||||
async def get_by_source(self, source: str, source_id: str) -> Track | None:
|
||||
row = (
|
||||
await self._session.execute(
|
||||
select(TrackModel).where(
|
||||
TrackModel.source == source,
|
||||
TrackModel.source_id == source_id,
|
||||
)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
return _to_entity(row) if row is not None else None
|
||||
|
||||
async def add(
|
||||
self,
|
||||
*,
|
||||
id: uuid.UUID,
|
||||
title: str,
|
||||
artist_id: uuid.UUID,
|
||||
file_path: str,
|
||||
file_format: str,
|
||||
file_size: int,
|
||||
source: str,
|
||||
source_id: str,
|
||||
metadata_status: str,
|
||||
added_by: uuid.UUID | None,
|
||||
) -> Track:
|
||||
row = TrackModel(
|
||||
id=id,
|
||||
title=title,
|
||||
artist_id=artist_id,
|
||||
file_path=file_path,
|
||||
file_format=file_format,
|
||||
file_size=file_size,
|
||||
source=source,
|
||||
source_id=source_id,
|
||||
metadata_status=metadata_status,
|
||||
added_by=added_by,
|
||||
)
|
||||
self._session.add(row)
|
||||
await self._session.flush()
|
||||
await self._session.refresh(row)
|
||||
return _to_entity(row)
|
||||
|
||||
async def delete(self, track_id: uuid.UUID) -> None:
|
||||
row = await self._session.get(TrackModel, track_id)
|
||||
if row is not None:
|
||||
await self._session.delete(row)
|
||||
await self._session.flush()
|
||||
@@ -0,0 +1 @@
|
||||
"""File storage adapters."""
|
||||
@@ -0,0 +1,86 @@
|
||||
"""LocalFileStorage — stores files on the local filesystem."""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import anyio
|
||||
|
||||
from app.domain.entities.storage import ObjectStat
|
||||
from app.domain.errors import StorageError
|
||||
|
||||
_EXT_CONTENT_TYPE: dict[str, str] = {
|
||||
"mp3": "audio/mpeg",
|
||||
"flac": "audio/flac",
|
||||
"m4a": "audio/mp4",
|
||||
"aac": "audio/aac",
|
||||
"ogg": "audio/ogg",
|
||||
"opus": "audio/ogg",
|
||||
"wav": "audio/wav",
|
||||
"wma": "audio/x-ms-wma",
|
||||
"aiff": "audio/aiff",
|
||||
"aif": "audio/aiff",
|
||||
}
|
||||
|
||||
|
||||
class LocalFileStorage:
|
||||
def __init__(self, media_path: Path) -> None:
|
||||
self._media_path = media_path
|
||||
|
||||
async def save_file(self, key: str, src_path: Path) -> int:
|
||||
dest = self._media_path / key
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
part = dest.with_suffix(dest.suffix + ".part")
|
||||
shutil.copyfile(str(src_path), str(part))
|
||||
os.replace(str(part), str(dest))
|
||||
return dest.stat().st_size
|
||||
|
||||
async def open_range(
|
||||
self, key: str, start: int, end: int | None
|
||||
) -> tuple[AsyncIterator[bytes], int]:
|
||||
path = self._media_path / key
|
||||
if not path.exists():
|
||||
raise StorageError(f"Object not found: {key}")
|
||||
total_size = path.stat().st_size
|
||||
|
||||
_start = start
|
||||
_end = end
|
||||
_total_size = total_size
|
||||
_path = path
|
||||
|
||||
async def _iter() -> AsyncGenerator[bytes]:
|
||||
async with await anyio.open_file(_path, "rb") as f:
|
||||
await f.seek(_start)
|
||||
remaining = (_end - _start + 1) if _end is not None else (_total_size - _start)
|
||||
while remaining > 0:
|
||||
chunk: bytes = await f.read(min(65536, remaining))
|
||||
if not chunk:
|
||||
break
|
||||
yield chunk
|
||||
remaining -= len(chunk)
|
||||
|
||||
aiter: AsyncIterator[bytes] = _iter()
|
||||
return aiter, total_size
|
||||
|
||||
async def stat(self, key: str) -> ObjectStat:
|
||||
path = self._media_path / key
|
||||
if not path.exists():
|
||||
raise StorageError(f"Object not found: {key}")
|
||||
st = path.stat()
|
||||
ext = path.suffix.lower().lstrip(".")
|
||||
return ObjectStat(size=st.st_size, content_type=_EXT_CONTENT_TYPE.get(ext))
|
||||
|
||||
async def exists(self, key: str) -> bool:
|
||||
return (self._media_path / key).exists()
|
||||
|
||||
async def delete(self, key: str) -> None:
|
||||
(self._media_path / key).unlink(missing_ok=True)
|
||||
|
||||
def as_local_path(self, key: str) -> AbstractAsyncContextManager[Path]:
|
||||
return self._as_local_path_cm(key)
|
||||
|
||||
@asynccontextmanager
|
||||
async def _as_local_path_cm(self, key: str) -> AsyncGenerator[Path]:
|
||||
yield self._media_path / key
|
||||
@@ -0,0 +1,16 @@
|
||||
"""File storage provider — singleton factory."""
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.infrastructure.storage.local import LocalFileStorage
|
||||
|
||||
_storage: LocalFileStorage | None = None
|
||||
|
||||
|
||||
def get_file_storage() -> LocalFileStorage:
|
||||
global _storage
|
||||
if _storage is None:
|
||||
settings = get_settings()
|
||||
if settings.storage_backend == "s3":
|
||||
raise NotImplementedError("S3 storage not yet implemented.")
|
||||
_storage = LocalFileStorage(settings.media_path)
|
||||
return _storage
|
||||
@@ -10,7 +10,7 @@ from collections.abc import AsyncIterator
|
||||
|
||||
import pytest
|
||||
from app.core.security import Argon2PasswordHasher
|
||||
from app.infrastructure.db import Base, get_engine, session_scope
|
||||
from app.infrastructure.db import Base, dispose_engine, get_engine, session_scope
|
||||
from app.infrastructure.db.repositories import (
|
||||
SqlAlchemyRefreshTokenRepository,
|
||||
SqlAlchemyUserRepository,
|
||||
@@ -73,6 +73,7 @@ async def api() -> AsyncIterator[AsyncClient]:
|
||||
|
||||
async with get_engine().begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
await dispose_engine()
|
||||
|
||||
|
||||
async def _login(api: AsyncClient, username: str, password: str) -> tuple[str, str]:
|
||||
|
||||
@@ -0,0 +1,104 @@
|
||||
"""Unit tests for LocalFileStorage."""
|
||||
|
||||
import pytest
|
||||
from app.infrastructure.storage.local import LocalFileStorage
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
async def test_save_and_stat(tmp_path):
|
||||
storage = LocalFileStorage(tmp_path)
|
||||
src = tmp_path / "src.mp3"
|
||||
src.write_bytes(b"test audio data")
|
||||
|
||||
size = await storage.save_file("tracks/te/test.mp3", src)
|
||||
assert size == 15
|
||||
|
||||
stat = await storage.stat("tracks/te/test.mp3")
|
||||
assert stat.size == 15
|
||||
assert stat.content_type == "audio/mpeg"
|
||||
|
||||
|
||||
async def test_save_creates_parent_dirs(tmp_path):
|
||||
storage = LocalFileStorage(tmp_path)
|
||||
src = tmp_path / "src.flac"
|
||||
src.write_bytes(b"x")
|
||||
|
||||
await storage.save_file("tracks/ab/cdef.flac", src)
|
||||
assert (tmp_path / "tracks" / "ab" / "cdef.flac").exists()
|
||||
|
||||
|
||||
async def test_open_range_full(tmp_path):
|
||||
storage = LocalFileStorage(tmp_path)
|
||||
data = b"hello world" * 100
|
||||
src = tmp_path / "src.flac"
|
||||
src.write_bytes(data)
|
||||
await storage.save_file("tracks/he/hello.flac", src)
|
||||
|
||||
stream, total = await storage.open_range("tracks/he/hello.flac", 0, None)
|
||||
result = b"".join([chunk async for chunk in stream])
|
||||
assert result == data
|
||||
assert total == len(data)
|
||||
|
||||
|
||||
async def test_open_range_partial(tmp_path):
|
||||
storage = LocalFileStorage(tmp_path)
|
||||
data = b"0123456789"
|
||||
src = tmp_path / "src.mp3"
|
||||
src.write_bytes(data)
|
||||
await storage.save_file("tracks/sr/src.mp3", src)
|
||||
|
||||
stream, total = await storage.open_range("tracks/sr/src.mp3", 3, 7)
|
||||
result = b"".join([chunk async for chunk in stream])
|
||||
assert result == b"34567"
|
||||
assert total == 10
|
||||
|
||||
|
||||
async def test_open_range_from_offset_to_end(tmp_path):
|
||||
storage = LocalFileStorage(tmp_path)
|
||||
data = b"abcdefghij"
|
||||
src = tmp_path / "src.wav"
|
||||
src.write_bytes(data)
|
||||
await storage.save_file("tracks/sr/src.wav", src)
|
||||
|
||||
stream, total = await storage.open_range("tracks/sr/src.wav", 5, None)
|
||||
result = b"".join([chunk async for chunk in stream])
|
||||
assert result == b"fghij"
|
||||
assert total == 10
|
||||
|
||||
|
||||
async def test_exists_and_delete(tmp_path):
|
||||
storage = LocalFileStorage(tmp_path)
|
||||
src = tmp_path / "src.ogg"
|
||||
src.write_bytes(b"ogg data")
|
||||
await storage.save_file("tracks/sr/src.ogg", src)
|
||||
|
||||
assert await storage.exists("tracks/sr/src.ogg") is True
|
||||
await storage.delete("tracks/sr/src.ogg")
|
||||
assert await storage.exists("tracks/sr/src.ogg") is False
|
||||
|
||||
|
||||
async def test_delete_missing_is_noop(tmp_path):
|
||||
storage = LocalFileStorage(tmp_path)
|
||||
await storage.delete("tracks/no/nope.mp3")
|
||||
|
||||
|
||||
async def test_as_local_path(tmp_path):
|
||||
storage = LocalFileStorage(tmp_path)
|
||||
src = tmp_path / "src.mp3"
|
||||
src.write_bytes(b"local bytes")
|
||||
await storage.save_file("tracks/lo/local.mp3", src)
|
||||
|
||||
async with storage.as_local_path("tracks/lo/local.mp3") as path:
|
||||
assert path.read_bytes() == b"local bytes"
|
||||
|
||||
|
||||
async def test_stat_unknown_extension(tmp_path):
|
||||
storage = LocalFileStorage(tmp_path)
|
||||
src = tmp_path / "src.xyz"
|
||||
src.write_bytes(b"mystery")
|
||||
await storage.save_file("tracks/my/mystery.xyz", src)
|
||||
|
||||
stat = await storage.stat("tracks/my/mystery.xyz")
|
||||
assert stat.size == 7
|
||||
assert stat.content_type is None
|
||||
@@ -0,0 +1,185 @@
|
||||
"""Integration tests for upload and streaming endpoints.
|
||||
|
||||
Requires a reachable Postgres; skips otherwise.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from collections.abc import AsyncIterator
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from app.core.config import get_settings
|
||||
from app.infrastructure.db import Base, dispose_engine, get_engine, session_scope
|
||||
from app.infrastructure.db.repositories import (
|
||||
SqlAlchemyRefreshTokenRepository,
|
||||
SqlAlchemyUserRepository,
|
||||
)
|
||||
from asgi_lifespan import LifespanManager
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
_db_reachable_cache: bool | None = None
|
||||
|
||||
|
||||
async def _db_reachable() -> bool:
|
||||
global _db_reachable_cache
|
||||
if _db_reachable_cache is not None:
|
||||
return _db_reachable_cache
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
try:
|
||||
async with asyncio.timeout(3):
|
||||
async with get_engine().connect() as conn:
|
||||
await conn.execute(text("SELECT 1"))
|
||||
_db_reachable_cache = True
|
||||
except Exception:
|
||||
_db_reachable_cache = False
|
||||
return _db_reachable_cache
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def api(tmp_path: Path) -> AsyncIterator[AsyncClient]:
|
||||
if not await _db_reachable():
|
||||
pytest.skip("Postgres not reachable — integration test skipped.")
|
||||
|
||||
os.environ["MEDIA_PATH"] = str(tmp_path)
|
||||
get_settings.cache_clear()
|
||||
|
||||
# Also reset the file storage singleton so it picks up the new media_path
|
||||
import app.infrastructure.storage.provider as _storage_provider
|
||||
|
||||
_storage_provider._storage = None
|
||||
|
||||
try:
|
||||
async with get_engine().begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
from app.application.user_service import UserService
|
||||
from app.core.security import Argon2PasswordHasher
|
||||
|
||||
async with session_scope() as session:
|
||||
await UserService(
|
||||
users=SqlAlchemyUserRepository(session),
|
||||
refresh_tokens=SqlAlchemyRefreshTokenRepository(session),
|
||||
hasher=Argon2PasswordHasher(),
|
||||
).create_user(username="testuser", password="testpass1", is_superuser=False)
|
||||
|
||||
from app.main import create_app
|
||||
|
||||
app = create_app()
|
||||
async with LifespanManager(app):
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
yield client
|
||||
|
||||
async with get_engine().begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
await dispose_engine()
|
||||
finally:
|
||||
_storage_provider._storage = None
|
||||
os.environ.pop("MEDIA_PATH", None)
|
||||
get_settings.cache_clear()
|
||||
|
||||
|
||||
async def _login(api: AsyncClient) -> str:
|
||||
resp = await api.post(
|
||||
"/api/v1/auth/login", json={"username": "testuser", "password": "testpass1"}
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
return str(resp.json()["access_token"])
|
||||
|
||||
|
||||
async def test_upload_creates_track(api: AsyncClient) -> None:
|
||||
token = await _login(api)
|
||||
audio = b"fake mp3 bytes" * 100
|
||||
|
||||
resp = await api.post(
|
||||
"/api/v1/upload",
|
||||
files={"file": ("song.mp3", audio, "audio/mpeg")},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert resp.status_code == 200, resp.text
|
||||
body = resp.json()
|
||||
assert body["already_exists"] is False
|
||||
assert body["title"] == "song"
|
||||
assert "track_id" in body
|
||||
|
||||
|
||||
async def test_upload_dedup(api: AsyncClient) -> None:
|
||||
token = await _login(api)
|
||||
audio = b"same content" * 50
|
||||
|
||||
first = await api.post(
|
||||
"/api/v1/upload",
|
||||
files={"file": ("a.mp3", audio, "audio/mpeg")},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert first.status_code == 200
|
||||
assert first.json()["already_exists"] is False
|
||||
|
||||
second = await api.post(
|
||||
"/api/v1/upload",
|
||||
files={"file": ("b.mp3", audio, "audio/mpeg")},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert second.status_code == 200
|
||||
assert second.json()["already_exists"] is True
|
||||
assert second.json()["track_id"] == first.json()["track_id"]
|
||||
|
||||
|
||||
async def test_stream_full(api: AsyncClient) -> None:
|
||||
token = await _login(api)
|
||||
audio = b"audio data for streaming" * 10
|
||||
|
||||
up = await api.post(
|
||||
"/api/v1/upload",
|
||||
files={"file": ("track.mp3", audio, "audio/mpeg")},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert up.status_code == 200
|
||||
track_id = up.json()["track_id"]
|
||||
|
||||
resp = await api.get(f"/api/v1/stream/{track_id}")
|
||||
assert resp.status_code == 200
|
||||
assert resp.content == audio
|
||||
assert resp.headers["content-type"].startswith("audio/mpeg")
|
||||
assert "accept-ranges" in resp.headers
|
||||
|
||||
|
||||
async def test_stream_range(api: AsyncClient) -> None:
|
||||
token = await _login(api)
|
||||
audio = b"0123456789" * 10
|
||||
|
||||
up = await api.post(
|
||||
"/api/v1/upload",
|
||||
files={"file": ("range.mp3", audio, "audio/mpeg")},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert up.status_code == 200
|
||||
track_id = up.json()["track_id"]
|
||||
|
||||
resp = await api.get(
|
||||
f"/api/v1/stream/{track_id}",
|
||||
headers={"Range": "bytes=0-9"},
|
||||
)
|
||||
assert resp.status_code == 206
|
||||
assert resp.content == b"0123456789"
|
||||
assert resp.headers["content-range"] == f"bytes 0-9/{len(audio)}"
|
||||
assert resp.headers["content-length"] == "10"
|
||||
|
||||
|
||||
async def test_stream_not_found(api: AsyncClient) -> None:
|
||||
resp = await api.get("/api/v1/stream/00000000-0000-0000-0000-000000000000")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
async def test_upload_requires_auth(api: AsyncClient) -> None:
|
||||
resp = await api.post(
|
||||
"/api/v1/upload",
|
||||
files={"file": ("x.mp3", b"data", "audio/mpeg")},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
Reference in New Issue
Block a user