Compare commits
2 Commits
a8348e145a
...
4ade6939b6
| Author | SHA1 | Date | |
|---|---|---|---|
| 4ade6939b6 | |||
| 5c5df5d3cc |
@@ -0,0 +1,25 @@
|
|||||||
|
"""rename track file_path to storage_uri
|
||||||
|
|
||||||
|
Revision ID: 20260608_storage_uri
|
||||||
|
Revises: e670d6c41d0c
|
||||||
|
Create Date: 2026-06-08 11:32:00.000000
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "20260608_storage_uri"
|
||||||
|
down_revision: str | None = "e670d6c41d0c"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.alter_column("tracks", "file_path", new_column_name="storage_uri")
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.alter_column("tracks", "storage_uri", new_column_name="file_path")
|
||||||
@@ -167,3 +167,23 @@ async def get_current_superuser(user: CurrentUser) -> User:
|
|||||||
|
|
||||||
|
|
||||||
SuperUser = Annotated[User, Depends(get_current_superuser)]
|
SuperUser = Annotated[User, Depends(get_current_superuser)]
|
||||||
|
|
||||||
|
|
||||||
|
async def get_streaming_user(
|
||||||
|
auth: AuthServiceDep,
|
||||||
|
credentials: BearerDep,
|
||||||
|
token: str | None = None,
|
||||||
|
) -> User:
|
||||||
|
"""Authenticate a stream request.
|
||||||
|
|
||||||
|
The browser ``<audio>`` element cannot send an ``Authorization`` header, so
|
||||||
|
the access token is accepted as a ``?token=`` query param; native clients may
|
||||||
|
still use a bearer header. Either way it's the same access token.
|
||||||
|
"""
|
||||||
|
raw = token or (credentials.credentials if credentials else None)
|
||||||
|
if not raw:
|
||||||
|
raise AuthenticationError("Missing access token.")
|
||||||
|
return await auth.authenticate_access(raw)
|
||||||
|
|
||||||
|
|
||||||
|
StreamUser = Annotated[User, Depends(get_streaming_user)]
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from typing import Annotated
|
|||||||
from fastapi import APIRouter, Header
|
from fastapi import APIRouter, Header
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
from app.api.deps import StreamingServiceDep
|
from app.api.deps import StreamingServiceDep, StreamUser
|
||||||
|
|
||||||
router = APIRouter(prefix="/stream", tags=["streaming"])
|
router = APIRouter(prefix="/stream", tags=["streaming"])
|
||||||
|
|
||||||
@@ -15,6 +15,7 @@ router = APIRouter(prefix="/stream", tags=["streaming"])
|
|||||||
async def stream_track(
|
async def stream_track(
|
||||||
track_id: uuid.UUID,
|
track_id: uuid.UUID,
|
||||||
service: StreamingServiceDep,
|
service: StreamingServiceDep,
|
||||||
|
_user: StreamUser,
|
||||||
range_header: Annotated[str | None, Header(alias="Range")] = None,
|
range_header: Annotated[str | None, Header(alias="Range")] = None,
|
||||||
) -> StreamingResponse:
|
) -> StreamingResponse:
|
||||||
result = await service.open_stream(track_id, range_header)
|
result = await service.open_stream(track_id, range_header)
|
||||||
|
|||||||
@@ -130,7 +130,7 @@ async def delete_track(
|
|||||||
if track is None:
|
if track is None:
|
||||||
raise NotFoundError(f"Track {track_id} not found.")
|
raise NotFoundError(f"Track {track_id} not found.")
|
||||||
await track_repo.delete(track_id)
|
await track_repo.delete(track_id)
|
||||||
await storage.delete(track.file_path)
|
await storage.delete(track.storage_uri)
|
||||||
return Response(status_code=204)
|
return Response(status_code=204)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ class StreamingService:
|
|||||||
if track is None:
|
if track is None:
|
||||||
raise NotFoundError("Track not found.")
|
raise NotFoundError("Track not found.")
|
||||||
|
|
||||||
stat = await self._storage.stat(track.file_path)
|
stat = await self._storage.stat(track.storage_uri)
|
||||||
total_size = stat.size
|
total_size = stat.size
|
||||||
content_type = stat.content_type or _FORMAT_CONTENT_TYPE.get(
|
content_type = stat.content_type or _FORMAT_CONTENT_TYPE.get(
|
||||||
track.file_format.lower(), "application/octet-stream"
|
track.file_format.lower(), "application/octet-stream"
|
||||||
@@ -81,7 +81,7 @@ class StreamingService:
|
|||||||
|
|
||||||
start, end, is_partial = _parse_range(range_header, total_size)
|
start, end, is_partial = _parse_range(range_header, total_size)
|
||||||
|
|
||||||
stream, _ = await self._storage.open_range(track.file_path, start, end)
|
stream, _ = await self._storage.open_range(track.storage_uri, start, end)
|
||||||
|
|
||||||
actual_end = end if end is not None else total_size - 1
|
actual_end = end if end is not None else total_size - 1
|
||||||
content_length = actual_end - start + 1
|
content_length = actual_end - start + 1
|
||||||
|
|||||||
@@ -92,7 +92,7 @@ class UploadService:
|
|||||||
id=track_id,
|
id=track_id,
|
||||||
title=title,
|
title=title,
|
||||||
artist_id=artist.id,
|
artist_id=artist.id,
|
||||||
file_path=key,
|
storage_uri=key,
|
||||||
file_format=ext,
|
file_format=ext,
|
||||||
file_size=file_size,
|
file_size=file_size,
|
||||||
source="upload",
|
source="upload",
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ class Track:
|
|||||||
title: str
|
title: str
|
||||||
artist_id: uuid.UUID
|
artist_id: uuid.UUID
|
||||||
album_id: uuid.UUID | None
|
album_id: uuid.UUID | None
|
||||||
file_path: str
|
storage_uri: str
|
||||||
file_format: str
|
file_format: str
|
||||||
file_size: int
|
file_size: int
|
||||||
source: str
|
source: str
|
||||||
|
|||||||
+1
-1
@@ -100,7 +100,7 @@ class TrackRepository(Protocol):
|
|||||||
id: uuid.UUID,
|
id: uuid.UUID,
|
||||||
title: str,
|
title: str,
|
||||||
artist_id: uuid.UUID,
|
artist_id: uuid.UUID,
|
||||||
file_path: str,
|
storage_uri: str,
|
||||||
file_format: str,
|
file_format: str,
|
||||||
file_size: int,
|
file_size: int,
|
||||||
source: str,
|
source: str,
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ class TrackModel(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
|||||||
year: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
year: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||||
|
|
||||||
# -- file (original, stored as-is) -----------------------------------
|
# -- file (original, stored as-is) -----------------------------------
|
||||||
file_path: Mapped[str] = mapped_column(String(2048), nullable=False)
|
storage_uri: Mapped[str] = mapped_column(String(2048), nullable=False)
|
||||||
file_format: Mapped[str] = mapped_column(String(32), nullable=False)
|
file_format: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||||
file_size: Mapped[int] = mapped_column(Integer, nullable=False)
|
file_size: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
bitrate: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
bitrate: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ def _track_to_entity(row: TrackModel) -> Track:
|
|||||||
title=row.title,
|
title=row.title,
|
||||||
artist_id=row.artist_id,
|
artist_id=row.artist_id,
|
||||||
album_id=row.album_id,
|
album_id=row.album_id,
|
||||||
file_path=row.file_path,
|
storage_uri=row.storage_uri,
|
||||||
file_format=row.file_format,
|
file_format=row.file_format,
|
||||||
file_size=row.file_size,
|
file_size=row.file_size,
|
||||||
source=row.source,
|
source=row.source,
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ def _track_to_entity(row: TrackModel) -> Track:
|
|||||||
title=row.title,
|
title=row.title,
|
||||||
artist_id=row.artist_id,
|
artist_id=row.artist_id,
|
||||||
album_id=row.album_id,
|
album_id=row.album_id,
|
||||||
file_path=row.file_path,
|
storage_uri=row.storage_uri,
|
||||||
file_format=row.file_format,
|
file_format=row.file_format,
|
||||||
file_size=row.file_size,
|
file_size=row.file_size,
|
||||||
source=row.source,
|
source=row.source,
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ def _to_entity(row: TrackModel) -> Track:
|
|||||||
title=row.title,
|
title=row.title,
|
||||||
artist_id=row.artist_id,
|
artist_id=row.artist_id,
|
||||||
album_id=row.album_id,
|
album_id=row.album_id,
|
||||||
file_path=row.file_path,
|
storage_uri=row.storage_uri,
|
||||||
file_format=row.file_format,
|
file_format=row.file_format,
|
||||||
file_size=row.file_size,
|
file_size=row.file_size,
|
||||||
source=row.source,
|
source=row.source,
|
||||||
@@ -56,7 +56,7 @@ class SqlAlchemyTrackRepository:
|
|||||||
id: uuid.UUID,
|
id: uuid.UUID,
|
||||||
title: str,
|
title: str,
|
||||||
artist_id: uuid.UUID,
|
artist_id: uuid.UUID,
|
||||||
file_path: str,
|
storage_uri: str,
|
||||||
file_format: str,
|
file_format: str,
|
||||||
file_size: int,
|
file_size: int,
|
||||||
source: str,
|
source: str,
|
||||||
@@ -68,7 +68,7 @@ class SqlAlchemyTrackRepository:
|
|||||||
id=id,
|
id=id,
|
||||||
title=title,
|
title=title,
|
||||||
artist_id=artist_id,
|
artist_id=artist_id,
|
||||||
file_path=file_path,
|
storage_uri=storage_uri,
|
||||||
file_format=file_format,
|
file_format=file_format,
|
||||||
file_size=file_size,
|
file_size=file_size,
|
||||||
source=source,
|
source=source,
|
||||||
|
|||||||
@@ -1,16 +1,31 @@
|
|||||||
"""File storage provider — singleton factory."""
|
"""File storage provider — singleton factory wired from config."""
|
||||||
|
|
||||||
from app.core.config import get_settings
|
from app.core.config import get_settings
|
||||||
|
from app.domain.ports import FileStorage
|
||||||
from app.infrastructure.storage.local import LocalFileStorage
|
from app.infrastructure.storage.local import LocalFileStorage
|
||||||
|
from app.infrastructure.storage.s3 import S3FileStorage
|
||||||
|
|
||||||
_storage: LocalFileStorage | None = None
|
_storage: FileStorage | None = None
|
||||||
|
|
||||||
|
|
||||||
def get_file_storage() -> LocalFileStorage:
|
def get_file_storage() -> FileStorage:
|
||||||
global _storage
|
global _storage
|
||||||
if _storage is None:
|
if _storage is None:
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
if settings.storage_backend == "s3":
|
if settings.storage_backend == "s3":
|
||||||
raise NotImplementedError("S3 storage not yet implemented.")
|
if not settings.s3_bucket:
|
||||||
_storage = LocalFileStorage(settings.media_path)
|
raise RuntimeError("S3_BUCKET must be set when STORAGE_BACKEND=s3")
|
||||||
|
_storage = S3FileStorage(
|
||||||
|
settings.s3_bucket,
|
||||||
|
endpoint_url=settings.s3_endpoint_url,
|
||||||
|
region_name=settings.s3_region,
|
||||||
|
access_key=settings.s3_access_key.get_secret_value()
|
||||||
|
if settings.s3_access_key
|
||||||
|
else None,
|
||||||
|
secret_key=settings.s3_secret_key.get_secret_value()
|
||||||
|
if settings.s3_secret_key
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
_storage = LocalFileStorage(settings.media_path)
|
||||||
return _storage
|
return _storage
|
||||||
|
|||||||
@@ -0,0 +1,157 @@
|
|||||||
|
"""S3FileStorage — stores files in any S3-compatible object store."""
|
||||||
|
|
||||||
|
import tempfile
|
||||||
|
from collections.abc import AsyncGenerator, AsyncIterator
|
||||||
|
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import aioboto3
|
||||||
|
import anyio
|
||||||
|
from botocore.exceptions import ClientError
|
||||||
|
|
||||||
|
from app.domain.entities.storage import ObjectStat
|
||||||
|
from app.domain.errors import StorageError
|
||||||
|
|
||||||
|
_EXT_CONTENT_TYPE: dict[str, str] = {
|
||||||
|
"mp3": "audio/mpeg",
|
||||||
|
"flac": "audio/flac",
|
||||||
|
"m4a": "audio/mp4",
|
||||||
|
"aac": "audio/aac",
|
||||||
|
"ogg": "audio/ogg",
|
||||||
|
"opus": "audio/ogg",
|
||||||
|
"wav": "audio/wav",
|
||||||
|
"wma": "audio/x-ms-wma",
|
||||||
|
"aiff": "audio/aiff",
|
||||||
|
"aif": "audio/aiff",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _not_found(key: str, exc: Exception) -> StorageError:
|
||||||
|
err = StorageError(f"Object not found: {key}")
|
||||||
|
err.__cause__ = exc
|
||||||
|
return err
|
||||||
|
|
||||||
|
|
||||||
|
def _is_404(exc: ClientError) -> bool:
|
||||||
|
return exc.response["Error"]["Code"] in ("404", "NoSuchKey")
|
||||||
|
|
||||||
|
|
||||||
|
class S3FileStorage:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
bucket: str,
|
||||||
|
*,
|
||||||
|
endpoint_url: str | None = None,
|
||||||
|
region_name: str | None = None,
|
||||||
|
access_key: str | None = None,
|
||||||
|
secret_key: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
self._bucket = bucket
|
||||||
|
self._endpoint_url = endpoint_url
|
||||||
|
self._session: Any = aioboto3.Session(
|
||||||
|
aws_access_key_id=access_key,
|
||||||
|
aws_secret_access_key=secret_key,
|
||||||
|
region_name=region_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _client(self) -> Any:
|
||||||
|
return self._session.client("s3", endpoint_url=self._endpoint_url)
|
||||||
|
|
||||||
|
async def save_file(self, key: str, src_path: Path) -> int:
|
||||||
|
async with await anyio.open_file(src_path, "rb") as f:
|
||||||
|
content: bytes = await f.read()
|
||||||
|
async with self._client() as s3:
|
||||||
|
await s3.put_object(Bucket=self._bucket, Key=key, Body=content)
|
||||||
|
return len(content)
|
||||||
|
|
||||||
|
async def open_range(
|
||||||
|
self, key: str, start: int, end: int | None
|
||||||
|
) -> tuple[AsyncIterator[bytes], int]:
|
||||||
|
async with self._client() as s3:
|
||||||
|
try:
|
||||||
|
head = await s3.head_object(Bucket=self._bucket, Key=key)
|
||||||
|
except ClientError as exc:
|
||||||
|
if _is_404(exc):
|
||||||
|
raise _not_found(key, exc) from exc
|
||||||
|
raise StorageError(str(exc)) from exc
|
||||||
|
total_size: int = head["ContentLength"]
|
||||||
|
|
||||||
|
range_header = f"bytes={start}-{end}" if end is not None else f"bytes={start}-"
|
||||||
|
_bucket = self._bucket
|
||||||
|
_key = key
|
||||||
|
|
||||||
|
async def _stream() -> AsyncGenerator[bytes]:
|
||||||
|
async with self._client() as s3:
|
||||||
|
try:
|
||||||
|
resp = await s3.get_object(
|
||||||
|
Bucket=_bucket, Key=_key, Range=range_header
|
||||||
|
)
|
||||||
|
except ClientError as exc:
|
||||||
|
raise StorageError(str(exc)) from exc
|
||||||
|
body = resp["Body"]
|
||||||
|
while True:
|
||||||
|
chunk: bytes = await body.read(65536)
|
||||||
|
if not chunk:
|
||||||
|
break
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
return _stream(), total_size
|
||||||
|
|
||||||
|
async def stat(self, key: str) -> ObjectStat:
|
||||||
|
async with self._client() as s3:
|
||||||
|
try:
|
||||||
|
head = await s3.head_object(Bucket=self._bucket, Key=key)
|
||||||
|
except ClientError as exc:
|
||||||
|
if _is_404(exc):
|
||||||
|
raise _not_found(key, exc) from exc
|
||||||
|
raise StorageError(str(exc)) from exc
|
||||||
|
ext = Path(key).suffix.lower().lstrip(".")
|
||||||
|
return ObjectStat(
|
||||||
|
size=head["ContentLength"],
|
||||||
|
content_type=head.get("ContentType") or _EXT_CONTENT_TYPE.get(ext),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def exists(self, key: str) -> bool:
|
||||||
|
async with self._client() as s3:
|
||||||
|
try:
|
||||||
|
await s3.head_object(Bucket=self._bucket, Key=key)
|
||||||
|
return True
|
||||||
|
except ClientError as exc:
|
||||||
|
if _is_404(exc):
|
||||||
|
return False
|
||||||
|
raise StorageError(str(exc)) from exc
|
||||||
|
|
||||||
|
async def delete(self, key: str) -> None:
|
||||||
|
async with self._client() as s3:
|
||||||
|
try:
|
||||||
|
await s3.delete_object(Bucket=self._bucket, Key=key)
|
||||||
|
except ClientError as exc:
|
||||||
|
raise StorageError(str(exc)) from exc
|
||||||
|
|
||||||
|
def as_local_path(self, key: str) -> AbstractAsyncContextManager[Path]:
|
||||||
|
return self._as_local_path_cm(key)
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def _as_local_path_cm(self, key: str) -> AsyncGenerator[Path]:
|
||||||
|
suffix = Path(key).suffix
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as f:
|
||||||
|
tmp_path = Path(f.name)
|
||||||
|
try:
|
||||||
|
async with self._client() as s3:
|
||||||
|
try:
|
||||||
|
resp = await s3.get_object(Bucket=self._bucket, Key=key)
|
||||||
|
except ClientError as exc:
|
||||||
|
if _is_404(exc):
|
||||||
|
raise _not_found(key, exc) from exc
|
||||||
|
raise StorageError(str(exc)) from exc
|
||||||
|
async with await anyio.open_file(tmp_path, "wb") as out:
|
||||||
|
body = resp["Body"]
|
||||||
|
while True:
|
||||||
|
chunk: bytes = await body.read(65536)
|
||||||
|
if not chunk:
|
||||||
|
break
|
||||||
|
await out.write(chunk)
|
||||||
|
yield tmp_path
|
||||||
|
finally:
|
||||||
|
await anyio.Path(tmp_path).unlink(missing_ok=True)
|
||||||
@@ -23,6 +23,8 @@ dependencies = [
|
|||||||
"pwdlib[argon2]>=0.2.1",
|
"pwdlib[argon2]>=0.2.1",
|
||||||
# outbound http (ML client, MusicBrainz, AcoustID)
|
# outbound http (ML client, MusicBrainz, AcoustID)
|
||||||
"httpx>=0.28",
|
"httpx>=0.28",
|
||||||
|
# S3-compatible object storage
|
||||||
|
"aioboto3>=13.0",
|
||||||
# logging
|
# logging
|
||||||
"structlog>=24.4",
|
"structlog>=24.4",
|
||||||
]
|
]
|
||||||
@@ -80,6 +82,10 @@ disallow_untyped_defs = true
|
|||||||
module = ["arq.*"]
|
module = ["arq.*"]
|
||||||
ignore_missing_imports = true
|
ignore_missing_imports = true
|
||||||
|
|
||||||
|
[[tool.mypy.overrides]]
|
||||||
|
module = ["aioboto3.*", "aiobotocore.*", "botocore.*"]
|
||||||
|
ignore_missing_imports = true
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
asyncio_mode = "auto"
|
asyncio_mode = "auto"
|
||||||
testpaths = ["tests"]
|
testpaths = ["tests"]
|
||||||
|
|||||||
@@ -0,0 +1,254 @@
|
|||||||
|
"""Unit tests for S3FileStorage — all S3 calls are mocked."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import io
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from app.domain.errors import StorageError
|
||||||
|
from app.infrastructure.storage.s3 import S3FileStorage
|
||||||
|
|
||||||
|
|
||||||
|
def _make_storage(**kwargs: Any) -> S3FileStorage:
|
||||||
|
return S3FileStorage("test-bucket", **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _client_error(code: str) -> Exception:
|
||||||
|
from botocore.exceptions import ClientError
|
||||||
|
|
||||||
|
return ClientError({"Error": {"Code": code, "Message": code}}, "op")
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeBody:
|
||||||
|
"""Async-iterable body that yields chunks from a bytes buffer."""
|
||||||
|
|
||||||
|
def __init__(self, data: bytes, chunk_size: int = 65536) -> None:
|
||||||
|
self._buf = io.BytesIO(data)
|
||||||
|
self._chunk_size = chunk_size
|
||||||
|
|
||||||
|
async def read(self, size: int = -1) -> bytes:
|
||||||
|
return self._buf.read(size)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_client_ctx(s3_mock: Any) -> Any:
|
||||||
|
ctx = MagicMock()
|
||||||
|
ctx.__aenter__ = AsyncMock(return_value=s3_mock)
|
||||||
|
ctx.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
return ctx
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def storage() -> S3FileStorage:
|
||||||
|
return _make_storage()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# save_file
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_save_file_calls_put_object(tmp_path: Path, storage: S3FileStorage) -> None:
|
||||||
|
src = tmp_path / "track.mp3"
|
||||||
|
src.write_bytes(b"audio bytes")
|
||||||
|
|
||||||
|
s3 = AsyncMock()
|
||||||
|
with patch.object(storage, "_client", return_value=_make_client_ctx(s3)):
|
||||||
|
size = await storage.save_file("tracks/ab/track.mp3", src)
|
||||||
|
|
||||||
|
s3.put_object.assert_awaited_once_with(
|
||||||
|
Bucket="test-bucket", Key="tracks/ab/track.mp3", Body=b"audio bytes"
|
||||||
|
)
|
||||||
|
assert size == 11
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# stat
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_stat_returns_size_and_content_type(storage: S3FileStorage) -> None:
|
||||||
|
s3 = AsyncMock()
|
||||||
|
s3.head_object.return_value = {"ContentLength": 1024, "ContentType": "audio/mpeg"}
|
||||||
|
with patch.object(storage, "_client", return_value=_make_client_ctx(s3)):
|
||||||
|
stat = await storage.stat("tracks/ab/track.mp3")
|
||||||
|
|
||||||
|
assert stat.size == 1024
|
||||||
|
assert stat.content_type == "audio/mpeg"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_stat_falls_back_to_ext_content_type(storage: S3FileStorage) -> None:
|
||||||
|
s3 = AsyncMock()
|
||||||
|
s3.head_object.return_value = {"ContentLength": 500, "ContentType": None}
|
||||||
|
with patch.object(storage, "_client", return_value=_make_client_ctx(s3)):
|
||||||
|
stat = await storage.stat("tracks/ab/track.flac")
|
||||||
|
|
||||||
|
assert stat.content_type == "audio/flac"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_stat_not_found_raises_storage_error(storage: S3FileStorage) -> None:
|
||||||
|
s3 = AsyncMock()
|
||||||
|
s3.head_object.side_effect = _client_error("404")
|
||||||
|
with patch.object(storage, "_client", return_value=_make_client_ctx(s3)):
|
||||||
|
with pytest.raises(StorageError, match="not found"):
|
||||||
|
await storage.stat("tracks/missing.mp3")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# exists
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_exists_true(storage: S3FileStorage) -> None:
|
||||||
|
s3 = AsyncMock()
|
||||||
|
s3.head_object.return_value = {"ContentLength": 1}
|
||||||
|
with patch.object(storage, "_client", return_value=_make_client_ctx(s3)):
|
||||||
|
assert await storage.exists("tracks/ab/track.mp3") is True
|
||||||
|
|
||||||
|
|
||||||
|
async def test_exists_false_on_404(storage: S3FileStorage) -> None:
|
||||||
|
s3 = AsyncMock()
|
||||||
|
s3.head_object.side_effect = _client_error("NoSuchKey")
|
||||||
|
with patch.object(storage, "_client", return_value=_make_client_ctx(s3)):
|
||||||
|
assert await storage.exists("tracks/missing.mp3") is False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# delete
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_delete_calls_delete_object(storage: S3FileStorage) -> None:
|
||||||
|
s3 = AsyncMock()
|
||||||
|
with patch.object(storage, "_client", return_value=_make_client_ctx(s3)):
|
||||||
|
await storage.delete("tracks/ab/track.mp3")
|
||||||
|
|
||||||
|
s3.delete_object.assert_awaited_once_with(Bucket="test-bucket", Key="tracks/ab/track.mp3")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# open_range
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_open_range_full(storage: S3FileStorage) -> None:
|
||||||
|
data = b"hello world"
|
||||||
|
s3 = AsyncMock()
|
||||||
|
s3.head_object.return_value = {"ContentLength": len(data)}
|
||||||
|
s3.get_object.return_value = {"Body": _FakeBody(data)}
|
||||||
|
|
||||||
|
with patch.object(storage, "_client", return_value=_make_client_ctx(s3)):
|
||||||
|
stream, total = await storage.open_range("tracks/ab/t.mp3", 0, None)
|
||||||
|
chunks = [c async for c in stream]
|
||||||
|
|
||||||
|
assert b"".join(chunks) == data
|
||||||
|
assert total == len(data)
|
||||||
|
s3.get_object.assert_awaited_once_with(
|
||||||
|
Bucket="test-bucket", Key="tracks/ab/t.mp3", Range="bytes=0-"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_open_range_partial(storage: S3FileStorage) -> None:
|
||||||
|
full = b"0123456789"
|
||||||
|
ranged = b"34567"
|
||||||
|
s3 = AsyncMock()
|
||||||
|
s3.head_object.return_value = {"ContentLength": len(full)}
|
||||||
|
s3.get_object.return_value = {"Body": _FakeBody(ranged)}
|
||||||
|
|
||||||
|
with patch.object(storage, "_client", return_value=_make_client_ctx(s3)):
|
||||||
|
stream, total = await storage.open_range("tracks/ab/t.mp3", 3, 7)
|
||||||
|
result = b"".join([c async for c in stream])
|
||||||
|
|
||||||
|
assert result == ranged
|
||||||
|
assert total == len(full)
|
||||||
|
s3.get_object.assert_awaited_once_with(
|
||||||
|
Bucket="test-bucket", Key="tracks/ab/t.mp3", Range="bytes=3-7"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_open_range_not_found_raises_storage_error(storage: S3FileStorage) -> None:
|
||||||
|
s3 = AsyncMock()
|
||||||
|
s3.head_object.side_effect = _client_error("NoSuchKey")
|
||||||
|
|
||||||
|
with patch.object(storage, "_client", return_value=_make_client_ctx(s3)):
|
||||||
|
with pytest.raises(StorageError, match="not found"):
|
||||||
|
await storage.open_range("tracks/missing.mp3", 0, None)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# as_local_path
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_as_local_path_yields_file_with_content(storage: S3FileStorage) -> None:
|
||||||
|
data = b"local copy bytes"
|
||||||
|
s3 = AsyncMock()
|
||||||
|
s3.get_object.return_value = {"Body": _FakeBody(data)}
|
||||||
|
|
||||||
|
with patch.object(storage, "_client", return_value=_make_client_ctx(s3)):
|
||||||
|
async with storage.as_local_path("tracks/ab/track.mp3") as path:
|
||||||
|
assert path.exists()
|
||||||
|
assert path.read_bytes() == data
|
||||||
|
|
||||||
|
assert not path.exists()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_as_local_path_cleans_up_on_error(storage: S3FileStorage) -> None:
|
||||||
|
s3 = AsyncMock()
|
||||||
|
s3.get_object.side_effect = _client_error("NoSuchKey")
|
||||||
|
|
||||||
|
captured: list[Path] = []
|
||||||
|
|
||||||
|
with patch.object(storage, "_client", return_value=_make_client_ctx(s3)):
|
||||||
|
with pytest.raises(StorageError):
|
||||||
|
async with storage.as_local_path("tracks/missing.mp3") as path:
|
||||||
|
captured.append(path)
|
||||||
|
|
||||||
|
if captured:
|
||||||
|
assert not captured[0].exists()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# provider wiring
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_provider_returns_s3_storage_when_configured(tmp_path: Path) -> None:
|
||||||
|
from app.core.config import Settings
|
||||||
|
from app.infrastructure.storage import provider
|
||||||
|
|
||||||
|
provider._storage = None
|
||||||
|
|
||||||
|
mock_settings = Settings(
|
||||||
|
database_url="postgresql+asyncpg://x:x@localhost/x",
|
||||||
|
storage_backend="s3",
|
||||||
|
s3_bucket="my-bucket",
|
||||||
|
s3_region="us-east-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.infrastructure.storage.provider.get_settings", return_value=mock_settings):
|
||||||
|
storage_instance = provider.get_file_storage()
|
||||||
|
|
||||||
|
assert isinstance(storage_instance, S3FileStorage)
|
||||||
|
provider._storage = None # reset singleton for other tests
|
||||||
|
|
||||||
|
|
||||||
|
def test_provider_raises_when_s3_bucket_missing() -> None:
|
||||||
|
from app.core.config import Settings
|
||||||
|
from app.infrastructure.storage import provider
|
||||||
|
|
||||||
|
provider._storage = None
|
||||||
|
|
||||||
|
mock_settings = Settings(
|
||||||
|
database_url="postgresql+asyncpg://x:x@localhost/x",
|
||||||
|
storage_backend="s3",
|
||||||
|
s3_bucket=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.infrastructure.storage.provider.get_settings", return_value=mock_settings):
|
||||||
|
with pytest.raises(RuntimeError, match="S3_BUCKET"):
|
||||||
|
provider.get_file_storage()
|
||||||
|
|
||||||
|
provider._storage = None
|
||||||
@@ -143,7 +143,8 @@ async def test_stream_full(api: AsyncClient) -> None:
|
|||||||
assert up.status_code == 200
|
assert up.status_code == 200
|
||||||
track_id = up.json()["track_id"]
|
track_id = up.json()["track_id"]
|
||||||
|
|
||||||
resp = await api.get(f"/api/v1/stream/{track_id}")
|
# Browser <audio> can't send headers — auth rides on the ?token= query param.
|
||||||
|
resp = await api.get(f"/api/v1/stream/{track_id}?token={token}")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
assert resp.content == audio
|
assert resp.content == audio
|
||||||
assert resp.headers["content-type"].startswith("audio/mpeg")
|
assert resp.headers["content-type"].startswith("audio/mpeg")
|
||||||
@@ -164,7 +165,7 @@ async def test_stream_range(api: AsyncClient) -> None:
|
|||||||
|
|
||||||
resp = await api.get(
|
resp = await api.get(
|
||||||
f"/api/v1/stream/{track_id}",
|
f"/api/v1/stream/{track_id}",
|
||||||
headers={"Range": "bytes=0-9"},
|
headers={"Range": "bytes=0-9", "Authorization": f"Bearer {token}"},
|
||||||
)
|
)
|
||||||
assert resp.status_code == 206
|
assert resp.status_code == 206
|
||||||
assert resp.content == b"0123456789"
|
assert resp.content == b"0123456789"
|
||||||
@@ -173,10 +174,18 @@ async def test_stream_range(api: AsyncClient) -> None:
|
|||||||
|
|
||||||
|
|
||||||
async def test_stream_not_found(api: AsyncClient) -> None:
|
async def test_stream_not_found(api: AsyncClient) -> None:
|
||||||
resp = await api.get("/api/v1/stream/00000000-0000-0000-0000-000000000000")
|
token = await _login(api)
|
||||||
|
resp = await api.get(
|
||||||
|
f"/api/v1/stream/00000000-0000-0000-0000-000000000000?token={token}"
|
||||||
|
)
|
||||||
assert resp.status_code == 404
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
async def test_stream_requires_auth(api: AsyncClient) -> None:
|
||||||
|
resp = await api.get("/api/v1/stream/00000000-0000-0000-0000-000000000000")
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
async def test_upload_requires_auth(api: AsyncClient) -> None:
|
async def test_upload_requires_auth(api: AsyncClient) -> None:
|
||||||
resp = await api.post(
|
resp = await api.post(
|
||||||
"/api/v1/upload",
|
"/api/v1/upload",
|
||||||
|
|||||||
Reference in New Issue
Block a user