feat: local storage logic & endpoints

This commit is contained in:
Senko-san
2026-06-07 15:34:06 +03:00
parent dfd512a13f
commit 81ea93c371
23 changed files with 945 additions and 18 deletions
+31 -1
View File
@@ -15,17 +15,22 @@ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlalchemy.ext.asyncio import AsyncSession
from app.application.auth_service import AuthService
from app.application.streaming_service import StreamingService
from app.application.upload_service import UploadService
from app.application.user_service import UserService
from app.core.config import get_settings
from app.core.security import Argon2PasswordHasher, JwtTokenService
from app.domain.entities import User
from app.domain.errors import AuthenticationError, PermissionDeniedError
from app.domain.ports import PasswordHasher, TokenService
from app.domain.ports import FileStorage, PasswordHasher, TokenService
from app.infrastructure.db import get_sessionmaker
from app.infrastructure.db.repositories import (
SqlAlchemyArtistRepository,
SqlAlchemyRefreshTokenRepository,
SqlAlchemyTrackRepository,
SqlAlchemyUserRepository,
)
from app.infrastructure.storage.provider import get_file_storage
async def get_session() -> AsyncIterator[AsyncSession]:
@@ -77,6 +82,31 @@ AuthServiceDep = Annotated[AuthService, Depends(get_auth_service)]
UserServiceDep = Annotated[UserService, Depends(get_user_service)]
# -- file storage (process-cached) ---------------------------------------------
FileStorageDep = Annotated[FileStorage, Depends(get_file_storage)]
def get_upload_service(session: SessionDep, storage: FileStorageDep) -> UploadService:
settings = get_settings()
return UploadService(
tracks=SqlAlchemyTrackRepository(session),
artists=SqlAlchemyArtistRepository(session),
storage=storage,
tmp_dir=settings.upload_tmp_dir,
)
def get_streaming_service(session: SessionDep, storage: FileStorageDep) -> StreamingService:
return StreamingService(
tracks=SqlAlchemyTrackRepository(session),
storage=storage,
)
UploadServiceDep = Annotated[UploadService, Depends(get_upload_service)]
StreamingServiceDep = Annotated[StreamingService, Depends(get_streaming_service)]
# -- current user / authorization ----------------------------------------------
# auto_error=False: we raise domain AuthenticationError (mapped to 401) so the
# error envelope stays consistent with the rest of the API.
+11
View File
@@ -12,6 +12,8 @@ from app.domain.errors import (
DomainError,
NotFoundError,
PermissionDeniedError,
RangeNotSatisfiableError,
StorageError,
ValidationError,
)
@@ -25,6 +27,7 @@ _STATUS_BY_ERROR: dict[type[DomainError], int] = {
AuthenticationError: status.HTTP_401_UNAUTHORIZED,
PermissionDeniedError: status.HTTP_403_FORBIDDEN,
DependencyUnavailableError: status.HTTP_503_SERVICE_UNAVAILABLE,
StorageError: status.HTTP_500_INTERNAL_SERVER_ERROR,
}
@@ -33,6 +36,14 @@ def _error_body(code: str, message: str) -> dict[str, dict[str, str]]:
def register_exception_handlers(app: FastAPI) -> None:
@app.exception_handler(RangeNotSatisfiableError)
async def _handle_range_error(_request: Request, exc: RangeNotSatisfiableError) -> JSONResponse:
return JSONResponse(
status_code=status.HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE,
content=_error_body(exc.code, exc.message),
headers={"Content-Range": f"bytes */{exc.total_size}"},
)
@app.exception_handler(DomainError)
async def _handle_domain_error(_request: Request, exc: DomainError) -> JSONResponse:
http_status = _STATUS_BY_ERROR.get(type(exc), status.HTTP_400_BAD_REQUEST)
+11
View File
@@ -0,0 +1,11 @@
"""Schemas for upload responses."""
import uuid
from pydantic import BaseModel
class UploadResponse(BaseModel):
track_id: uuid.UUID
title: str
already_exists: bool
+27 -9
View File
@@ -1,20 +1,38 @@
"""Audio streaming endpoints: direct stream and HLS."""
"""Audio streaming endpoint direct stream with Range support."""
import uuid
from typing import Any
from typing import Annotated
from fastapi import APIRouter
from fastapi import APIRouter, Header
from fastapi.responses import StreamingResponse
from app.api.deps import StreamingServiceDep
router = APIRouter(prefix="/stream", tags=["streaming"])
@router.get("/{track_id}")
async def stream_track(track_id: uuid.UUID) -> Any: ...
async def stream_track(
track_id: uuid.UUID,
service: StreamingServiceDep,
range_header: Annotated[str | None, Header(alias="Range")] = None,
) -> StreamingResponse:
result = await service.open_stream(track_id, range_header)
headers = {
"Accept-Ranges": "bytes",
"Content-Length": str(result.content_length),
}
@router.get("/{track_id}/hls/playlist.m3u8")
async def hls_playlist(track_id: uuid.UUID) -> Any: ...
if result.is_partial:
headers["Content-Range"] = f"bytes {result.start}-{result.end}/{result.total_size}"
status_code = 206
else:
status_code = 200
@router.get("/{track_id}/hls/{segment}")
async def hls_segment(track_id: uuid.UUID, segment: str) -> Any: ...
return StreamingResponse(
result.stream,
status_code=status_code,
headers=headers,
media_type=result.content_type,
)
+17 -4
View File
@@ -1,11 +1,24 @@
"""Local file upload endpoint."""
from typing import Any
from typing import Annotated
from fastapi import APIRouter
from fastapi import APIRouter, File, UploadFile
from app.api.deps import CurrentUser, UploadServiceDep
from app.api.schemas.upload import UploadResponse
router = APIRouter(prefix="/upload", tags=["upload"])
@router.post("")
async def upload_file() -> Any: ...
@router.post("", response_model=UploadResponse)
async def upload_file(
file: Annotated[UploadFile, File()],
current_user: CurrentUser,
service: UploadServiceDep,
) -> UploadResponse:
result = await service.handle_upload(upload=file, user=current_user)
return UploadResponse(
track_id=result.track_id,
title=result.title,
already_exists=result.already_exists,
)
+97
View File
@@ -0,0 +1,97 @@
"""StreamingService — resolves a track and opens a byte-range stream."""
import re
import uuid
from collections.abc import AsyncIterator
from dataclasses import dataclass
from app.domain.errors import NotFoundError, RangeNotSatisfiableError
from app.domain.ports import FileStorage, TrackRepository
_FORMAT_CONTENT_TYPE: dict[str, str] = {
"mp3": "audio/mpeg",
"flac": "audio/flac",
"m4a": "audio/mp4",
"aac": "audio/aac",
"ogg": "audio/ogg",
"opus": "audio/ogg",
"wav": "audio/wav",
"wma": "audio/x-ms-wma",
"aiff": "audio/aiff",
"aif": "audio/aiff",
}
_RANGE_RE = re.compile(r"bytes=(\d+)-(\d*)")
@dataclass
class StreamResult:
stream: AsyncIterator[bytes]
total_size: int
content_length: int
content_type: str
start: int
end: int
is_partial: bool
def _parse_range(header: str | None, total_size: int) -> tuple[int, int | None, bool]:
"""Return (start, end, is_partial). Raises RangeNotSatisfiableError on invalid range."""
if header is None:
return 0, None, False
m = _RANGE_RE.fullmatch(header.strip())
if not m:
return 0, None, False # malformed → treat as absent per RFC 7233
start = int(m.group(1))
end: int | None = int(m.group(2)) if m.group(2) else None
if start >= total_size:
raise RangeNotSatisfiableError(total_size)
if end is not None:
if end >= total_size:
end = total_size - 1
if end < start:
raise RangeNotSatisfiableError(total_size)
return start, end, True
class StreamingService:
def __init__(self, tracks: TrackRepository, storage: FileStorage) -> None:
self._tracks = tracks
self._storage = storage
async def open_stream(
self,
track_id: uuid.UUID,
range_header: str | None,
) -> StreamResult:
track = await self._tracks.get_by_id(track_id)
if track is None:
raise NotFoundError("Track not found.")
stat = await self._storage.stat(track.file_path)
total_size = stat.size
content_type = stat.content_type or _FORMAT_CONTENT_TYPE.get(
track.file_format.lower(), "application/octet-stream"
)
start, end, is_partial = _parse_range(range_header, total_size)
stream, _ = await self._storage.open_range(track.file_path, start, end)
actual_end = end if end is not None else total_size - 1
content_length = actual_end - start + 1
return StreamResult(
stream=stream,
total_size=total_size,
content_length=content_length,
content_type=content_type,
start=start,
end=actual_end,
is_partial=is_partial,
)
+116
View File
@@ -0,0 +1,116 @@
"""UploadService — handles user file uploads."""
import contextlib
import hashlib
import os
import tempfile
import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import Protocol
import anyio
from app.domain.entities.user import User
from app.domain.ports import ArtistRepository, FileStorage, TrackRepository
class UploadFileProtocol(Protocol):
filename: str | None
async def read(self, size: int = -1) -> bytes: ...
@dataclass(frozen=True)
class UploadResult:
track_id: uuid.UUID
title: str
already_exists: bool
async def _stream_to_temp(upload: UploadFileProtocol, dest: Path) -> tuple[str, int]:
h = hashlib.sha256()
size = 0
async with await anyio.open_file(dest, "wb") as out:
while True:
chunk = await upload.read(65536)
if not chunk:
break
h.update(chunk)
await out.write(chunk)
size += len(chunk)
return h.hexdigest(), size
class UploadService:
def __init__(
self,
tracks: TrackRepository,
artists: ArtistRepository,
storage: FileStorage,
tmp_dir: Path | None = None,
) -> None:
self._tracks = tracks
self._artists = artists
self._storage = storage
self._tmp_dir = tmp_dir
async def handle_upload(
self,
*,
upload: UploadFileProtocol,
user: User,
) -> UploadResult:
filename = upload.filename or "unknown"
ext = Path(filename).suffix.lower().lstrip(".") or "bin"
title = Path(filename).stem or "Unknown"
fd, tmp_str = tempfile.mkstemp(
suffix=f".{ext}",
dir=str(self._tmp_dir) if self._tmp_dir else None,
)
tmp_path = Path(tmp_str)
try:
os.close(fd)
sha256_hex, file_size = await _stream_to_temp(upload, tmp_path)
existing = await self._tracks.get_by_source("upload", sha256_hex)
if existing is not None:
return UploadResult(
track_id=existing.id,
title=existing.title,
already_exists=True,
)
track_id = uuid.uuid4()
key = f"tracks/{str(track_id)[:2]}/{track_id}.{ext}"
await self._storage.save_file(key, tmp_path)
try:
artist = await self._artists.get_or_create("Unknown Artist")
track = await self._tracks.add(
id=track_id,
title=title,
artist_id=artist.id,
file_path=key,
file_format=ext,
file_size=file_size,
source="upload",
source_id=sha256_hex,
metadata_status="pending",
added_by=user.id,
)
except Exception:
with contextlib.suppress(Exception):
await self._storage.delete(key)
raise
# TODO(1D): enqueue metadata enrichment task
return UploadResult(
track_id=track.id,
title=track.title,
already_exists=False,
)
finally:
await anyio.Path(tmp_path).unlink(missing_ok=True)
+9
View File
@@ -49,6 +49,15 @@ class Settings(BaseSettings):
media_path: Path = Path("/data/media")
transcode_cache_path: Path = Path("/data/transcode-cache")
max_parallel_downloads: int = 2
storage_backend: Literal["local", "s3"] = "local"
upload_tmp_dir: Path | None = None
# -- S3 storage (deferred; set storage_backend="s3" to use) ----------
s3_endpoint_url: str | None = None
s3_bucket: str | None = None
s3_region: str | None = None
s3_access_key: SecretStr | None = None
s3_secret_key: SecretStr | None = None
# -- external services (all optional; graceful degradation) ----------
ml_service_url: str | None = None
+12
View File
@@ -0,0 +1,12 @@
"""File hashing utilities."""
import hashlib
from pathlib import Path
def sha256_of_file(path: Path) -> str:
h = hashlib.sha256()
with open(path, "rb") as f:
for chunk in iter(lambda: f.read(65536), b""):
h.update(chunk)
return h.hexdigest()
+3 -1
View File
@@ -1,5 +1,7 @@
"""Domain entities and value objects — pure, framework-free."""
from app.domain.entities.storage import ObjectStat
from app.domain.entities.track import Artist, Track
from app.domain.entities.user import Credentials, User
__all__ = ["Credentials", "User"]
__all__ = ["Artist", "Credentials", "ObjectStat", "Track", "User"]
+9
View File
@@ -0,0 +1,9 @@
"""Value objects for file storage."""
from dataclasses import dataclass
@dataclass(frozen=True, slots=True)
class ObjectStat:
size: int
content_type: str | None
+29
View File
@@ -0,0 +1,29 @@
"""Track and Artist domain entities."""
import datetime as dt
import uuid
from dataclasses import dataclass
@dataclass(frozen=True, slots=True)
class Artist:
id: uuid.UUID
name: str
created_at: dt.datetime
updated_at: dt.datetime
@dataclass(frozen=True, slots=True)
class Track:
id: uuid.UUID
title: str
artist_id: uuid.UUID
file_path: str
file_format: str
file_size: int
source: str
source_id: str
duration_seconds: int | None
metadata_status: str
created_at: dt.datetime
updated_at: dt.datetime
+16
View File
@@ -61,3 +61,19 @@ class DependencyUnavailableError(DomainError):
"""
code = "dependency_unavailable"
class StorageError(DomainError):
"""File storage operation failed."""
code = "storage_error"
class RangeNotSatisfiableError(DomainError):
"""Requested byte range cannot be satisfied."""
code = "range_not_satisfiable"
def __init__(self, total_size: int) -> None:
super().__init__("Requested range is not satisfiable.")
self.total_size = total_size
+40 -1
View File
@@ -7,9 +7,13 @@ are bound to these ports at the composition root (``app.api.deps``).
import datetime as dt
import uuid
from collections.abc import AsyncIterator
from contextlib import AbstractAsyncContextManager
from pathlib import Path
from typing import Protocol
from app.domain.entities import Credentials, User
from app.domain.entities import Credentials, ObjectStat, User
from app.domain.entities.track import Artist, Track
from app.domain.tokens import IssuedToken, TokenClaims, TokenType
@@ -56,3 +60,38 @@ class TokenService(Protocol):
"""Verify signature + expiry and return claims. Raises
:class:`~app.domain.errors.AuthenticationError` on any failure."""
...
class FileStorage(Protocol):
async def save_file(self, key: str, src_path: Path) -> int: ...
async def open_range(
self, key: str, start: int, end: int | None
) -> tuple[AsyncIterator[bytes], int]: ...
async def stat(self, key: str) -> ObjectStat: ...
async def exists(self, key: str) -> bool: ...
async def delete(self, key: str) -> None: ...
def as_local_path(self, key: str) -> AbstractAsyncContextManager[Path]: ...
class ArtistRepository(Protocol):
async def get_or_create(self, name: str) -> Artist: ...
class TrackRepository(Protocol):
async def get_by_id(self, track_id: uuid.UUID) -> Track | None: ...
async def get_by_source(self, source: str, source_id: str) -> Track | None: ...
async def add(
self,
*,
id: uuid.UUID,
title: str,
artist_id: uuid.UUID,
file_path: str,
file_format: str,
file_size: int,
source: str,
source_id: str,
metadata_status: str,
added_by: uuid.UUID | None,
) -> Track: ...
async def delete(self, track_id: uuid.UUID) -> None: ...
@@ -1,8 +1,15 @@
"""SQLAlchemy repository adapters implementing the domain ports."""
from app.infrastructure.db.repositories.artist_repository import SqlAlchemyArtistRepository
from app.infrastructure.db.repositories.refresh_token_repository import (
SqlAlchemyRefreshTokenRepository,
)
from app.infrastructure.db.repositories.track_repository import SqlAlchemyTrackRepository
from app.infrastructure.db.repositories.user_repository import SqlAlchemyUserRepository
__all__ = ["SqlAlchemyRefreshTokenRepository", "SqlAlchemyUserRepository"]
__all__ = [
"SqlAlchemyArtistRepository",
"SqlAlchemyRefreshTokenRepository",
"SqlAlchemyTrackRepository",
"SqlAlchemyUserRepository",
]
@@ -0,0 +1,32 @@
"""Artist repository — adapter over ``AsyncSession``."""
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.domain.entities.track import Artist
from app.infrastructure.db.models.artist import ArtistModel
def _to_entity(row: ArtistModel) -> Artist:
return Artist(
id=row.id,
name=row.name,
created_at=row.created_at,
updated_at=row.updated_at,
)
class SqlAlchemyArtistRepository:
def __init__(self, session: AsyncSession) -> None:
self._session = session
async def get_or_create(self, name: str) -> Artist:
row = (
await self._session.execute(select(ArtistModel).where(ArtistModel.name == name))
).scalar_one_or_none()
if row is None:
row = ArtistModel(name=name)
self._session.add(row)
await self._session.flush()
await self._session.refresh(row)
return _to_entity(row)
@@ -0,0 +1,83 @@
"""Track repository — adapter over ``AsyncSession``."""
import uuid
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.domain.entities.track import Track
from app.infrastructure.db.models.track import TrackModel
def _to_entity(row: TrackModel) -> Track:
return Track(
id=row.id,
title=row.title,
artist_id=row.artist_id,
file_path=row.file_path,
file_format=row.file_format,
file_size=row.file_size,
source=row.source,
source_id=row.source_id,
duration_seconds=row.duration_seconds,
metadata_status=row.metadata_status,
created_at=row.created_at,
updated_at=row.updated_at,
)
class SqlAlchemyTrackRepository:
def __init__(self, session: AsyncSession) -> None:
self._session = session
async def get_by_id(self, track_id: uuid.UUID) -> Track | None:
row = await self._session.get(TrackModel, track_id)
return _to_entity(row) if row is not None else None
async def get_by_source(self, source: str, source_id: str) -> Track | None:
row = (
await self._session.execute(
select(TrackModel).where(
TrackModel.source == source,
TrackModel.source_id == source_id,
)
)
).scalar_one_or_none()
return _to_entity(row) if row is not None else None
async def add(
self,
*,
id: uuid.UUID,
title: str,
artist_id: uuid.UUID,
file_path: str,
file_format: str,
file_size: int,
source: str,
source_id: str,
metadata_status: str,
added_by: uuid.UUID | None,
) -> Track:
row = TrackModel(
id=id,
title=title,
artist_id=artist_id,
file_path=file_path,
file_format=file_format,
file_size=file_size,
source=source,
source_id=source_id,
metadata_status=metadata_status,
added_by=added_by,
)
self._session.add(row)
await self._session.flush()
await self._session.refresh(row)
return _to_entity(row)
async def delete(self, track_id: uuid.UUID) -> None:
row = await self._session.get(TrackModel, track_id)
if row is not None:
await self._session.delete(row)
await self._session.flush()
+1
View File
@@ -0,0 +1 @@
"""File storage adapters."""
+86
View File
@@ -0,0 +1,86 @@
"""LocalFileStorage — stores files on the local filesystem."""
import os
import shutil
from collections.abc import AsyncGenerator, AsyncIterator
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from pathlib import Path
import anyio
from app.domain.entities.storage import ObjectStat
from app.domain.errors import StorageError
_EXT_CONTENT_TYPE: dict[str, str] = {
"mp3": "audio/mpeg",
"flac": "audio/flac",
"m4a": "audio/mp4",
"aac": "audio/aac",
"ogg": "audio/ogg",
"opus": "audio/ogg",
"wav": "audio/wav",
"wma": "audio/x-ms-wma",
"aiff": "audio/aiff",
"aif": "audio/aiff",
}
class LocalFileStorage:
def __init__(self, media_path: Path) -> None:
self._media_path = media_path
async def save_file(self, key: str, src_path: Path) -> int:
dest = self._media_path / key
dest.parent.mkdir(parents=True, exist_ok=True)
part = dest.with_suffix(dest.suffix + ".part")
shutil.copyfile(str(src_path), str(part))
os.replace(str(part), str(dest))
return dest.stat().st_size
async def open_range(
self, key: str, start: int, end: int | None
) -> tuple[AsyncIterator[bytes], int]:
path = self._media_path / key
if not path.exists():
raise StorageError(f"Object not found: {key}")
total_size = path.stat().st_size
_start = start
_end = end
_total_size = total_size
_path = path
async def _iter() -> AsyncGenerator[bytes]:
async with await anyio.open_file(_path, "rb") as f:
await f.seek(_start)
remaining = (_end - _start + 1) if _end is not None else (_total_size - _start)
while remaining > 0:
chunk: bytes = await f.read(min(65536, remaining))
if not chunk:
break
yield chunk
remaining -= len(chunk)
aiter: AsyncIterator[bytes] = _iter()
return aiter, total_size
async def stat(self, key: str) -> ObjectStat:
path = self._media_path / key
if not path.exists():
raise StorageError(f"Object not found: {key}")
st = path.stat()
ext = path.suffix.lower().lstrip(".")
return ObjectStat(size=st.st_size, content_type=_EXT_CONTENT_TYPE.get(ext))
async def exists(self, key: str) -> bool:
return (self._media_path / key).exists()
async def delete(self, key: str) -> None:
(self._media_path / key).unlink(missing_ok=True)
def as_local_path(self, key: str) -> AbstractAsyncContextManager[Path]:
return self._as_local_path_cm(key)
@asynccontextmanager
async def _as_local_path_cm(self, key: str) -> AsyncGenerator[Path]:
yield self._media_path / key
+16
View File
@@ -0,0 +1,16 @@
"""File storage provider — singleton factory."""
from app.core.config import get_settings
from app.infrastructure.storage.local import LocalFileStorage
_storage: LocalFileStorage | None = None
def get_file_storage() -> LocalFileStorage:
global _storage
if _storage is None:
settings = get_settings()
if settings.storage_backend == "s3":
raise NotImplementedError("S3 storage not yet implemented.")
_storage = LocalFileStorage(settings.media_path)
return _storage
+2 -1
View File
@@ -10,7 +10,7 @@ from collections.abc import AsyncIterator
import pytest
from app.core.security import Argon2PasswordHasher
from app.infrastructure.db import Base, get_engine, session_scope
from app.infrastructure.db import Base, dispose_engine, get_engine, session_scope
from app.infrastructure.db.repositories import (
SqlAlchemyRefreshTokenRepository,
SqlAlchemyUserRepository,
@@ -73,6 +73,7 @@ async def api() -> AsyncIterator[AsyncClient]:
async with get_engine().begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
await dispose_engine()
async def _login(api: AsyncClient, username: str, password: str) -> tuple[str, str]:
+104
View File
@@ -0,0 +1,104 @@
"""Unit tests for LocalFileStorage."""
import pytest
from app.infrastructure.storage.local import LocalFileStorage
pytestmark = pytest.mark.asyncio
async def test_save_and_stat(tmp_path):
storage = LocalFileStorage(tmp_path)
src = tmp_path / "src.mp3"
src.write_bytes(b"test audio data")
size = await storage.save_file("tracks/te/test.mp3", src)
assert size == 15
stat = await storage.stat("tracks/te/test.mp3")
assert stat.size == 15
assert stat.content_type == "audio/mpeg"
async def test_save_creates_parent_dirs(tmp_path):
storage = LocalFileStorage(tmp_path)
src = tmp_path / "src.flac"
src.write_bytes(b"x")
await storage.save_file("tracks/ab/cdef.flac", src)
assert (tmp_path / "tracks" / "ab" / "cdef.flac").exists()
async def test_open_range_full(tmp_path):
storage = LocalFileStorage(tmp_path)
data = b"hello world" * 100
src = tmp_path / "src.flac"
src.write_bytes(data)
await storage.save_file("tracks/he/hello.flac", src)
stream, total = await storage.open_range("tracks/he/hello.flac", 0, None)
result = b"".join([chunk async for chunk in stream])
assert result == data
assert total == len(data)
async def test_open_range_partial(tmp_path):
storage = LocalFileStorage(tmp_path)
data = b"0123456789"
src = tmp_path / "src.mp3"
src.write_bytes(data)
await storage.save_file("tracks/sr/src.mp3", src)
stream, total = await storage.open_range("tracks/sr/src.mp3", 3, 7)
result = b"".join([chunk async for chunk in stream])
assert result == b"34567"
assert total == 10
async def test_open_range_from_offset_to_end(tmp_path):
storage = LocalFileStorage(tmp_path)
data = b"abcdefghij"
src = tmp_path / "src.wav"
src.write_bytes(data)
await storage.save_file("tracks/sr/src.wav", src)
stream, total = await storage.open_range("tracks/sr/src.wav", 5, None)
result = b"".join([chunk async for chunk in stream])
assert result == b"fghij"
assert total == 10
async def test_exists_and_delete(tmp_path):
storage = LocalFileStorage(tmp_path)
src = tmp_path / "src.ogg"
src.write_bytes(b"ogg data")
await storage.save_file("tracks/sr/src.ogg", src)
assert await storage.exists("tracks/sr/src.ogg") is True
await storage.delete("tracks/sr/src.ogg")
assert await storage.exists("tracks/sr/src.ogg") is False
async def test_delete_missing_is_noop(tmp_path):
storage = LocalFileStorage(tmp_path)
await storage.delete("tracks/no/nope.mp3")
async def test_as_local_path(tmp_path):
storage = LocalFileStorage(tmp_path)
src = tmp_path / "src.mp3"
src.write_bytes(b"local bytes")
await storage.save_file("tracks/lo/local.mp3", src)
async with storage.as_local_path("tracks/lo/local.mp3") as path:
assert path.read_bytes() == b"local bytes"
async def test_stat_unknown_extension(tmp_path):
storage = LocalFileStorage(tmp_path)
src = tmp_path / "src.xyz"
src.write_bytes(b"mystery")
await storage.save_file("tracks/my/mystery.xyz", src)
stat = await storage.stat("tracks/my/mystery.xyz")
assert stat.size == 7
assert stat.content_type is None
+185
View File
@@ -0,0 +1,185 @@
"""Integration tests for upload and streaming endpoints.
Requires a reachable Postgres; skips otherwise.
"""
import asyncio
import os
from collections.abc import AsyncIterator
from pathlib import Path
import pytest
from app.core.config import get_settings
from app.infrastructure.db import Base, dispose_engine, get_engine, session_scope
from app.infrastructure.db.repositories import (
SqlAlchemyRefreshTokenRepository,
SqlAlchemyUserRepository,
)
from asgi_lifespan import LifespanManager
from httpx import ASGITransport, AsyncClient
pytestmark = pytest.mark.asyncio
_db_reachable_cache: bool | None = None
async def _db_reachable() -> bool:
global _db_reachable_cache
if _db_reachable_cache is not None:
return _db_reachable_cache
from sqlalchemy import text
try:
async with asyncio.timeout(3):
async with get_engine().connect() as conn:
await conn.execute(text("SELECT 1"))
_db_reachable_cache = True
except Exception:
_db_reachable_cache = False
return _db_reachable_cache
@pytest.fixture
async def api(tmp_path: Path) -> AsyncIterator[AsyncClient]:
if not await _db_reachable():
pytest.skip("Postgres not reachable — integration test skipped.")
os.environ["MEDIA_PATH"] = str(tmp_path)
get_settings.cache_clear()
# Also reset the file storage singleton so it picks up the new media_path
import app.infrastructure.storage.provider as _storage_provider
_storage_provider._storage = None
try:
async with get_engine().begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
await conn.run_sync(Base.metadata.create_all)
from app.application.user_service import UserService
from app.core.security import Argon2PasswordHasher
async with session_scope() as session:
await UserService(
users=SqlAlchemyUserRepository(session),
refresh_tokens=SqlAlchemyRefreshTokenRepository(session),
hasher=Argon2PasswordHasher(),
).create_user(username="testuser", password="testpass1", is_superuser=False)
from app.main import create_app
app = create_app()
async with LifespanManager(app):
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
yield client
async with get_engine().begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
await dispose_engine()
finally:
_storage_provider._storage = None
os.environ.pop("MEDIA_PATH", None)
get_settings.cache_clear()
async def _login(api: AsyncClient) -> str:
resp = await api.post(
"/api/v1/auth/login", json={"username": "testuser", "password": "testpass1"}
)
assert resp.status_code == 200
return str(resp.json()["access_token"])
async def test_upload_creates_track(api: AsyncClient) -> None:
token = await _login(api)
audio = b"fake mp3 bytes" * 100
resp = await api.post(
"/api/v1/upload",
files={"file": ("song.mp3", audio, "audio/mpeg")},
headers={"Authorization": f"Bearer {token}"},
)
assert resp.status_code == 200, resp.text
body = resp.json()
assert body["already_exists"] is False
assert body["title"] == "song"
assert "track_id" in body
async def test_upload_dedup(api: AsyncClient) -> None:
token = await _login(api)
audio = b"same content" * 50
first = await api.post(
"/api/v1/upload",
files={"file": ("a.mp3", audio, "audio/mpeg")},
headers={"Authorization": f"Bearer {token}"},
)
assert first.status_code == 200
assert first.json()["already_exists"] is False
second = await api.post(
"/api/v1/upload",
files={"file": ("b.mp3", audio, "audio/mpeg")},
headers={"Authorization": f"Bearer {token}"},
)
assert second.status_code == 200
assert second.json()["already_exists"] is True
assert second.json()["track_id"] == first.json()["track_id"]
async def test_stream_full(api: AsyncClient) -> None:
token = await _login(api)
audio = b"audio data for streaming" * 10
up = await api.post(
"/api/v1/upload",
files={"file": ("track.mp3", audio, "audio/mpeg")},
headers={"Authorization": f"Bearer {token}"},
)
assert up.status_code == 200
track_id = up.json()["track_id"]
resp = await api.get(f"/api/v1/stream/{track_id}")
assert resp.status_code == 200
assert resp.content == audio
assert resp.headers["content-type"].startswith("audio/mpeg")
assert "accept-ranges" in resp.headers
async def test_stream_range(api: AsyncClient) -> None:
token = await _login(api)
audio = b"0123456789" * 10
up = await api.post(
"/api/v1/upload",
files={"file": ("range.mp3", audio, "audio/mpeg")},
headers={"Authorization": f"Bearer {token}"},
)
assert up.status_code == 200
track_id = up.json()["track_id"]
resp = await api.get(
f"/api/v1/stream/{track_id}",
headers={"Range": "bytes=0-9"},
)
assert resp.status_code == 206
assert resp.content == b"0123456789"
assert resp.headers["content-range"] == f"bytes 0-9/{len(audio)}"
assert resp.headers["content-length"] == "10"
async def test_stream_not_found(api: AsyncClient) -> None:
resp = await api.get("/api/v1/stream/00000000-0000-0000-0000-000000000000")
assert resp.status_code == 404
async def test_upload_requires_auth(api: AsyncClient) -> None:
resp = await api.post(
"/api/v1/upload",
files={"file": ("x.mp3", b"data", "audio/mpeg")},
)
assert resp.status_code == 401