Compare commits
2 Commits
a8348e145a
...
4ade6939b6
| Author | SHA1 | Date | |
|---|---|---|---|
| 4ade6939b6 | |||
| 5c5df5d3cc |
@@ -0,0 +1,25 @@
|
||||
"""rename track file_path to storage_uri
|
||||
|
||||
Revision ID: 20260608_storage_uri
|
||||
Revises: e670d6c41d0c
|
||||
Create Date: 2026-06-08 11:32:00.000000
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "20260608_storage_uri"
|
||||
down_revision: str | None = "e670d6c41d0c"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column("tracks", "file_path", new_column_name="storage_uri")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column("tracks", "storage_uri", new_column_name="file_path")
|
||||
@@ -167,3 +167,23 @@ async def get_current_superuser(user: CurrentUser) -> User:
|
||||
|
||||
|
||||
SuperUser = Annotated[User, Depends(get_current_superuser)]
|
||||
|
||||
|
||||
async def get_streaming_user(
|
||||
auth: AuthServiceDep,
|
||||
credentials: BearerDep,
|
||||
token: str | None = None,
|
||||
) -> User:
|
||||
"""Authenticate a stream request.
|
||||
|
||||
The browser ``<audio>`` element cannot send an ``Authorization`` header, so
|
||||
the access token is accepted as a ``?token=`` query param; native clients may
|
||||
still use a bearer header. Either way it's the same access token.
|
||||
"""
|
||||
raw = token or (credentials.credentials if credentials else None)
|
||||
if not raw:
|
||||
raise AuthenticationError("Missing access token.")
|
||||
return await auth.authenticate_access(raw)
|
||||
|
||||
|
||||
StreamUser = Annotated[User, Depends(get_streaming_user)]
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Annotated
|
||||
from fastapi import APIRouter, Header
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app.api.deps import StreamingServiceDep
|
||||
from app.api.deps import StreamingServiceDep, StreamUser
|
||||
|
||||
router = APIRouter(prefix="/stream", tags=["streaming"])
|
||||
|
||||
@@ -15,6 +15,7 @@ router = APIRouter(prefix="/stream", tags=["streaming"])
|
||||
async def stream_track(
|
||||
track_id: uuid.UUID,
|
||||
service: StreamingServiceDep,
|
||||
_user: StreamUser,
|
||||
range_header: Annotated[str | None, Header(alias="Range")] = None,
|
||||
) -> StreamingResponse:
|
||||
result = await service.open_stream(track_id, range_header)
|
||||
|
||||
@@ -130,7 +130,7 @@ async def delete_track(
|
||||
if track is None:
|
||||
raise NotFoundError(f"Track {track_id} not found.")
|
||||
await track_repo.delete(track_id)
|
||||
await storage.delete(track.file_path)
|
||||
await storage.delete(track.storage_uri)
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
|
||||
@@ -73,7 +73,7 @@ class StreamingService:
|
||||
if track is None:
|
||||
raise NotFoundError("Track not found.")
|
||||
|
||||
stat = await self._storage.stat(track.file_path)
|
||||
stat = await self._storage.stat(track.storage_uri)
|
||||
total_size = stat.size
|
||||
content_type = stat.content_type or _FORMAT_CONTENT_TYPE.get(
|
||||
track.file_format.lower(), "application/octet-stream"
|
||||
@@ -81,7 +81,7 @@ class StreamingService:
|
||||
|
||||
start, end, is_partial = _parse_range(range_header, total_size)
|
||||
|
||||
stream, _ = await self._storage.open_range(track.file_path, start, end)
|
||||
stream, _ = await self._storage.open_range(track.storage_uri, start, end)
|
||||
|
||||
actual_end = end if end is not None else total_size - 1
|
||||
content_length = actual_end - start + 1
|
||||
|
||||
@@ -92,7 +92,7 @@ class UploadService:
|
||||
id=track_id,
|
||||
title=title,
|
||||
artist_id=artist.id,
|
||||
file_path=key,
|
||||
storage_uri=key,
|
||||
file_format=ext,
|
||||
file_size=file_size,
|
||||
source="upload",
|
||||
|
||||
@@ -19,7 +19,7 @@ class Track:
|
||||
title: str
|
||||
artist_id: uuid.UUID
|
||||
album_id: uuid.UUID | None
|
||||
file_path: str
|
||||
storage_uri: str
|
||||
file_format: str
|
||||
file_size: int
|
||||
source: str
|
||||
|
||||
+1
-1
@@ -100,7 +100,7 @@ class TrackRepository(Protocol):
|
||||
id: uuid.UUID,
|
||||
title: str,
|
||||
artist_id: uuid.UUID,
|
||||
file_path: str,
|
||||
storage_uri: str,
|
||||
file_format: str,
|
||||
file_size: int,
|
||||
source: str,
|
||||
|
||||
@@ -40,7 +40,7 @@ class TrackModel(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||
year: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
# -- file (original, stored as-is) -----------------------------------
|
||||
file_path: Mapped[str] = mapped_column(String(2048), nullable=False)
|
||||
storage_uri: Mapped[str] = mapped_column(String(2048), nullable=False)
|
||||
file_format: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||
file_size: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
bitrate: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
@@ -30,7 +30,7 @@ def _track_to_entity(row: TrackModel) -> Track:
|
||||
title=row.title,
|
||||
artist_id=row.artist_id,
|
||||
album_id=row.album_id,
|
||||
file_path=row.file_path,
|
||||
storage_uri=row.storage_uri,
|
||||
file_format=row.file_format,
|
||||
file_size=row.file_size,
|
||||
source=row.source,
|
||||
|
||||
@@ -29,7 +29,7 @@ def _track_to_entity(row: TrackModel) -> Track:
|
||||
title=row.title,
|
||||
artist_id=row.artist_id,
|
||||
album_id=row.album_id,
|
||||
file_path=row.file_path,
|
||||
storage_uri=row.storage_uri,
|
||||
file_format=row.file_format,
|
||||
file_size=row.file_size,
|
||||
source=row.source,
|
||||
|
||||
@@ -17,7 +17,7 @@ def _to_entity(row: TrackModel) -> Track:
|
||||
title=row.title,
|
||||
artist_id=row.artist_id,
|
||||
album_id=row.album_id,
|
||||
file_path=row.file_path,
|
||||
storage_uri=row.storage_uri,
|
||||
file_format=row.file_format,
|
||||
file_size=row.file_size,
|
||||
source=row.source,
|
||||
@@ -56,7 +56,7 @@ class SqlAlchemyTrackRepository:
|
||||
id: uuid.UUID,
|
||||
title: str,
|
||||
artist_id: uuid.UUID,
|
||||
file_path: str,
|
||||
storage_uri: str,
|
||||
file_format: str,
|
||||
file_size: int,
|
||||
source: str,
|
||||
@@ -68,7 +68,7 @@ class SqlAlchemyTrackRepository:
|
||||
id=id,
|
||||
title=title,
|
||||
artist_id=artist_id,
|
||||
file_path=file_path,
|
||||
storage_uri=storage_uri,
|
||||
file_format=file_format,
|
||||
file_size=file_size,
|
||||
source=source,
|
||||
|
||||
@@ -1,16 +1,31 @@
|
||||
"""File storage provider — singleton factory."""
|
||||
"""File storage provider — singleton factory wired from config."""
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.domain.ports import FileStorage
|
||||
from app.infrastructure.storage.local import LocalFileStorage
|
||||
from app.infrastructure.storage.s3 import S3FileStorage
|
||||
|
||||
_storage: LocalFileStorage | None = None
|
||||
_storage: FileStorage | None = None
|
||||
|
||||
|
||||
def get_file_storage() -> LocalFileStorage:
|
||||
def get_file_storage() -> FileStorage:
|
||||
global _storage
|
||||
if _storage is None:
|
||||
settings = get_settings()
|
||||
if settings.storage_backend == "s3":
|
||||
raise NotImplementedError("S3 storage not yet implemented.")
|
||||
if not settings.s3_bucket:
|
||||
raise RuntimeError("S3_BUCKET must be set when STORAGE_BACKEND=s3")
|
||||
_storage = S3FileStorage(
|
||||
settings.s3_bucket,
|
||||
endpoint_url=settings.s3_endpoint_url,
|
||||
region_name=settings.s3_region,
|
||||
access_key=settings.s3_access_key.get_secret_value()
|
||||
if settings.s3_access_key
|
||||
else None,
|
||||
secret_key=settings.s3_secret_key.get_secret_value()
|
||||
if settings.s3_secret_key
|
||||
else None,
|
||||
)
|
||||
else:
|
||||
_storage = LocalFileStorage(settings.media_path)
|
||||
return _storage
|
||||
|
||||
@@ -0,0 +1,157 @@
|
||||
"""S3FileStorage — stores files in any S3-compatible object store."""
|
||||
|
||||
import tempfile
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import aioboto3
|
||||
import anyio
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
from app.domain.entities.storage import ObjectStat
|
||||
from app.domain.errors import StorageError
|
||||
|
||||
_EXT_CONTENT_TYPE: dict[str, str] = {
|
||||
"mp3": "audio/mpeg",
|
||||
"flac": "audio/flac",
|
||||
"m4a": "audio/mp4",
|
||||
"aac": "audio/aac",
|
||||
"ogg": "audio/ogg",
|
||||
"opus": "audio/ogg",
|
||||
"wav": "audio/wav",
|
||||
"wma": "audio/x-ms-wma",
|
||||
"aiff": "audio/aiff",
|
||||
"aif": "audio/aiff",
|
||||
}
|
||||
|
||||
|
||||
def _not_found(key: str, exc: Exception) -> StorageError:
|
||||
err = StorageError(f"Object not found: {key}")
|
||||
err.__cause__ = exc
|
||||
return err
|
||||
|
||||
|
||||
def _is_404(exc: ClientError) -> bool:
|
||||
return exc.response["Error"]["Code"] in ("404", "NoSuchKey")
|
||||
|
||||
|
||||
class S3FileStorage:
|
||||
def __init__(
|
||||
self,
|
||||
bucket: str,
|
||||
*,
|
||||
endpoint_url: str | None = None,
|
||||
region_name: str | None = None,
|
||||
access_key: str | None = None,
|
||||
secret_key: str | None = None,
|
||||
) -> None:
|
||||
self._bucket = bucket
|
||||
self._endpoint_url = endpoint_url
|
||||
self._session: Any = aioboto3.Session(
|
||||
aws_access_key_id=access_key,
|
||||
aws_secret_access_key=secret_key,
|
||||
region_name=region_name,
|
||||
)
|
||||
|
||||
def _client(self) -> Any:
|
||||
return self._session.client("s3", endpoint_url=self._endpoint_url)
|
||||
|
||||
async def save_file(self, key: str, src_path: Path) -> int:
|
||||
async with await anyio.open_file(src_path, "rb") as f:
|
||||
content: bytes = await f.read()
|
||||
async with self._client() as s3:
|
||||
await s3.put_object(Bucket=self._bucket, Key=key, Body=content)
|
||||
return len(content)
|
||||
|
||||
async def open_range(
|
||||
self, key: str, start: int, end: int | None
|
||||
) -> tuple[AsyncIterator[bytes], int]:
|
||||
async with self._client() as s3:
|
||||
try:
|
||||
head = await s3.head_object(Bucket=self._bucket, Key=key)
|
||||
except ClientError as exc:
|
||||
if _is_404(exc):
|
||||
raise _not_found(key, exc) from exc
|
||||
raise StorageError(str(exc)) from exc
|
||||
total_size: int = head["ContentLength"]
|
||||
|
||||
range_header = f"bytes={start}-{end}" if end is not None else f"bytes={start}-"
|
||||
_bucket = self._bucket
|
||||
_key = key
|
||||
|
||||
async def _stream() -> AsyncGenerator[bytes]:
|
||||
async with self._client() as s3:
|
||||
try:
|
||||
resp = await s3.get_object(
|
||||
Bucket=_bucket, Key=_key, Range=range_header
|
||||
)
|
||||
except ClientError as exc:
|
||||
raise StorageError(str(exc)) from exc
|
||||
body = resp["Body"]
|
||||
while True:
|
||||
chunk: bytes = await body.read(65536)
|
||||
if not chunk:
|
||||
break
|
||||
yield chunk
|
||||
|
||||
return _stream(), total_size
|
||||
|
||||
async def stat(self, key: str) -> ObjectStat:
|
||||
async with self._client() as s3:
|
||||
try:
|
||||
head = await s3.head_object(Bucket=self._bucket, Key=key)
|
||||
except ClientError as exc:
|
||||
if _is_404(exc):
|
||||
raise _not_found(key, exc) from exc
|
||||
raise StorageError(str(exc)) from exc
|
||||
ext = Path(key).suffix.lower().lstrip(".")
|
||||
return ObjectStat(
|
||||
size=head["ContentLength"],
|
||||
content_type=head.get("ContentType") or _EXT_CONTENT_TYPE.get(ext),
|
||||
)
|
||||
|
||||
async def exists(self, key: str) -> bool:
|
||||
async with self._client() as s3:
|
||||
try:
|
||||
await s3.head_object(Bucket=self._bucket, Key=key)
|
||||
return True
|
||||
except ClientError as exc:
|
||||
if _is_404(exc):
|
||||
return False
|
||||
raise StorageError(str(exc)) from exc
|
||||
|
||||
async def delete(self, key: str) -> None:
|
||||
async with self._client() as s3:
|
||||
try:
|
||||
await s3.delete_object(Bucket=self._bucket, Key=key)
|
||||
except ClientError as exc:
|
||||
raise StorageError(str(exc)) from exc
|
||||
|
||||
def as_local_path(self, key: str) -> AbstractAsyncContextManager[Path]:
|
||||
return self._as_local_path_cm(key)
|
||||
|
||||
@asynccontextmanager
|
||||
async def _as_local_path_cm(self, key: str) -> AsyncGenerator[Path]:
|
||||
suffix = Path(key).suffix
|
||||
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as f:
|
||||
tmp_path = Path(f.name)
|
||||
try:
|
||||
async with self._client() as s3:
|
||||
try:
|
||||
resp = await s3.get_object(Bucket=self._bucket, Key=key)
|
||||
except ClientError as exc:
|
||||
if _is_404(exc):
|
||||
raise _not_found(key, exc) from exc
|
||||
raise StorageError(str(exc)) from exc
|
||||
async with await anyio.open_file(tmp_path, "wb") as out:
|
||||
body = resp["Body"]
|
||||
while True:
|
||||
chunk: bytes = await body.read(65536)
|
||||
if not chunk:
|
||||
break
|
||||
await out.write(chunk)
|
||||
yield tmp_path
|
||||
finally:
|
||||
await anyio.Path(tmp_path).unlink(missing_ok=True)
|
||||
@@ -23,6 +23,8 @@ dependencies = [
|
||||
"pwdlib[argon2]>=0.2.1",
|
||||
# outbound http (ML client, MusicBrainz, AcoustID)
|
||||
"httpx>=0.28",
|
||||
# S3-compatible object storage
|
||||
"aioboto3>=13.0",
|
||||
# logging
|
||||
"structlog>=24.4",
|
||||
]
|
||||
@@ -80,6 +82,10 @@ disallow_untyped_defs = true
|
||||
module = ["arq.*"]
|
||||
ignore_missing_imports = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = ["aioboto3.*", "aiobotocore.*", "botocore.*"]
|
||||
ignore_missing_imports = true
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
testpaths = ["tests"]
|
||||
|
||||
@@ -0,0 +1,254 @@
|
||||
"""Unit tests for S3FileStorage — all S3 calls are mocked."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from app.domain.errors import StorageError
|
||||
from app.infrastructure.storage.s3 import S3FileStorage
|
||||
|
||||
|
||||
def _make_storage(**kwargs: Any) -> S3FileStorage:
|
||||
return S3FileStorage("test-bucket", **kwargs)
|
||||
|
||||
|
||||
def _client_error(code: str) -> Exception:
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
return ClientError({"Error": {"Code": code, "Message": code}}, "op")
|
||||
|
||||
|
||||
class _FakeBody:
|
||||
"""Async-iterable body that yields chunks from a bytes buffer."""
|
||||
|
||||
def __init__(self, data: bytes, chunk_size: int = 65536) -> None:
|
||||
self._buf = io.BytesIO(data)
|
||||
self._chunk_size = chunk_size
|
||||
|
||||
async def read(self, size: int = -1) -> bytes:
|
||||
return self._buf.read(size)
|
||||
|
||||
|
||||
def _make_client_ctx(s3_mock: Any) -> Any:
|
||||
ctx = MagicMock()
|
||||
ctx.__aenter__ = AsyncMock(return_value=s3_mock)
|
||||
ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
return ctx
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def storage() -> S3FileStorage:
|
||||
return _make_storage()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# save_file
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_save_file_calls_put_object(tmp_path: Path, storage: S3FileStorage) -> None:
|
||||
src = tmp_path / "track.mp3"
|
||||
src.write_bytes(b"audio bytes")
|
||||
|
||||
s3 = AsyncMock()
|
||||
with patch.object(storage, "_client", return_value=_make_client_ctx(s3)):
|
||||
size = await storage.save_file("tracks/ab/track.mp3", src)
|
||||
|
||||
s3.put_object.assert_awaited_once_with(
|
||||
Bucket="test-bucket", Key="tracks/ab/track.mp3", Body=b"audio bytes"
|
||||
)
|
||||
assert size == 11
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# stat
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_stat_returns_size_and_content_type(storage: S3FileStorage) -> None:
|
||||
s3 = AsyncMock()
|
||||
s3.head_object.return_value = {"ContentLength": 1024, "ContentType": "audio/mpeg"}
|
||||
with patch.object(storage, "_client", return_value=_make_client_ctx(s3)):
|
||||
stat = await storage.stat("tracks/ab/track.mp3")
|
||||
|
||||
assert stat.size == 1024
|
||||
assert stat.content_type == "audio/mpeg"
|
||||
|
||||
|
||||
async def test_stat_falls_back_to_ext_content_type(storage: S3FileStorage) -> None:
|
||||
s3 = AsyncMock()
|
||||
s3.head_object.return_value = {"ContentLength": 500, "ContentType": None}
|
||||
with patch.object(storage, "_client", return_value=_make_client_ctx(s3)):
|
||||
stat = await storage.stat("tracks/ab/track.flac")
|
||||
|
||||
assert stat.content_type == "audio/flac"
|
||||
|
||||
|
||||
async def test_stat_not_found_raises_storage_error(storage: S3FileStorage) -> None:
|
||||
s3 = AsyncMock()
|
||||
s3.head_object.side_effect = _client_error("404")
|
||||
with patch.object(storage, "_client", return_value=_make_client_ctx(s3)):
|
||||
with pytest.raises(StorageError, match="not found"):
|
||||
await storage.stat("tracks/missing.mp3")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# exists
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_exists_true(storage: S3FileStorage) -> None:
|
||||
s3 = AsyncMock()
|
||||
s3.head_object.return_value = {"ContentLength": 1}
|
||||
with patch.object(storage, "_client", return_value=_make_client_ctx(s3)):
|
||||
assert await storage.exists("tracks/ab/track.mp3") is True
|
||||
|
||||
|
||||
async def test_exists_false_on_404(storage: S3FileStorage) -> None:
|
||||
s3 = AsyncMock()
|
||||
s3.head_object.side_effect = _client_error("NoSuchKey")
|
||||
with patch.object(storage, "_client", return_value=_make_client_ctx(s3)):
|
||||
assert await storage.exists("tracks/missing.mp3") is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# delete
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_delete_calls_delete_object(storage: S3FileStorage) -> None:
|
||||
s3 = AsyncMock()
|
||||
with patch.object(storage, "_client", return_value=_make_client_ctx(s3)):
|
||||
await storage.delete("tracks/ab/track.mp3")
|
||||
|
||||
s3.delete_object.assert_awaited_once_with(Bucket="test-bucket", Key="tracks/ab/track.mp3")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# open_range
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_open_range_full(storage: S3FileStorage) -> None:
|
||||
data = b"hello world"
|
||||
s3 = AsyncMock()
|
||||
s3.head_object.return_value = {"ContentLength": len(data)}
|
||||
s3.get_object.return_value = {"Body": _FakeBody(data)}
|
||||
|
||||
with patch.object(storage, "_client", return_value=_make_client_ctx(s3)):
|
||||
stream, total = await storage.open_range("tracks/ab/t.mp3", 0, None)
|
||||
chunks = [c async for c in stream]
|
||||
|
||||
assert b"".join(chunks) == data
|
||||
assert total == len(data)
|
||||
s3.get_object.assert_awaited_once_with(
|
||||
Bucket="test-bucket", Key="tracks/ab/t.mp3", Range="bytes=0-"
|
||||
)
|
||||
|
||||
|
||||
async def test_open_range_partial(storage: S3FileStorage) -> None:
|
||||
full = b"0123456789"
|
||||
ranged = b"34567"
|
||||
s3 = AsyncMock()
|
||||
s3.head_object.return_value = {"ContentLength": len(full)}
|
||||
s3.get_object.return_value = {"Body": _FakeBody(ranged)}
|
||||
|
||||
with patch.object(storage, "_client", return_value=_make_client_ctx(s3)):
|
||||
stream, total = await storage.open_range("tracks/ab/t.mp3", 3, 7)
|
||||
result = b"".join([c async for c in stream])
|
||||
|
||||
assert result == ranged
|
||||
assert total == len(full)
|
||||
s3.get_object.assert_awaited_once_with(
|
||||
Bucket="test-bucket", Key="tracks/ab/t.mp3", Range="bytes=3-7"
|
||||
)
|
||||
|
||||
|
||||
async def test_open_range_not_found_raises_storage_error(storage: S3FileStorage) -> None:
|
||||
s3 = AsyncMock()
|
||||
s3.head_object.side_effect = _client_error("NoSuchKey")
|
||||
|
||||
with patch.object(storage, "_client", return_value=_make_client_ctx(s3)):
|
||||
with pytest.raises(StorageError, match="not found"):
|
||||
await storage.open_range("tracks/missing.mp3", 0, None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# as_local_path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_as_local_path_yields_file_with_content(storage: S3FileStorage) -> None:
|
||||
data = b"local copy bytes"
|
||||
s3 = AsyncMock()
|
||||
s3.get_object.return_value = {"Body": _FakeBody(data)}
|
||||
|
||||
with patch.object(storage, "_client", return_value=_make_client_ctx(s3)):
|
||||
async with storage.as_local_path("tracks/ab/track.mp3") as path:
|
||||
assert path.exists()
|
||||
assert path.read_bytes() == data
|
||||
|
||||
assert not path.exists()
|
||||
|
||||
|
||||
async def test_as_local_path_cleans_up_on_error(storage: S3FileStorage) -> None:
|
||||
s3 = AsyncMock()
|
||||
s3.get_object.side_effect = _client_error("NoSuchKey")
|
||||
|
||||
captured: list[Path] = []
|
||||
|
||||
with patch.object(storage, "_client", return_value=_make_client_ctx(s3)):
|
||||
with pytest.raises(StorageError):
|
||||
async with storage.as_local_path("tracks/missing.mp3") as path:
|
||||
captured.append(path)
|
||||
|
||||
if captured:
|
||||
assert not captured[0].exists()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# provider wiring
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_provider_returns_s3_storage_when_configured(tmp_path: Path) -> None:
|
||||
from app.core.config import Settings
|
||||
from app.infrastructure.storage import provider
|
||||
|
||||
provider._storage = None
|
||||
|
||||
mock_settings = Settings(
|
||||
database_url="postgresql+asyncpg://x:x@localhost/x",
|
||||
storage_backend="s3",
|
||||
s3_bucket="my-bucket",
|
||||
s3_region="us-east-1",
|
||||
)
|
||||
|
||||
with patch("app.infrastructure.storage.provider.get_settings", return_value=mock_settings):
|
||||
storage_instance = provider.get_file_storage()
|
||||
|
||||
assert isinstance(storage_instance, S3FileStorage)
|
||||
provider._storage = None # reset singleton for other tests
|
||||
|
||||
|
||||
def test_provider_raises_when_s3_bucket_missing() -> None:
|
||||
from app.core.config import Settings
|
||||
from app.infrastructure.storage import provider
|
||||
|
||||
provider._storage = None
|
||||
|
||||
mock_settings = Settings(
|
||||
database_url="postgresql+asyncpg://x:x@localhost/x",
|
||||
storage_backend="s3",
|
||||
s3_bucket=None,
|
||||
)
|
||||
|
||||
with patch("app.infrastructure.storage.provider.get_settings", return_value=mock_settings):
|
||||
with pytest.raises(RuntimeError, match="S3_BUCKET"):
|
||||
provider.get_file_storage()
|
||||
|
||||
provider._storage = None
|
||||
@@ -143,7 +143,8 @@ async def test_stream_full(api: AsyncClient) -> None:
|
||||
assert up.status_code == 200
|
||||
track_id = up.json()["track_id"]
|
||||
|
||||
resp = await api.get(f"/api/v1/stream/{track_id}")
|
||||
# Browser <audio> can't send headers — auth rides on the ?token= query param.
|
||||
resp = await api.get(f"/api/v1/stream/{track_id}?token={token}")
|
||||
assert resp.status_code == 200
|
||||
assert resp.content == audio
|
||||
assert resp.headers["content-type"].startswith("audio/mpeg")
|
||||
@@ -164,7 +165,7 @@ async def test_stream_range(api: AsyncClient) -> None:
|
||||
|
||||
resp = await api.get(
|
||||
f"/api/v1/stream/{track_id}",
|
||||
headers={"Range": "bytes=0-9"},
|
||||
headers={"Range": "bytes=0-9", "Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert resp.status_code == 206
|
||||
assert resp.content == b"0123456789"
|
||||
@@ -173,10 +174,18 @@ async def test_stream_range(api: AsyncClient) -> None:
|
||||
|
||||
|
||||
async def test_stream_not_found(api: AsyncClient) -> None:
|
||||
resp = await api.get("/api/v1/stream/00000000-0000-0000-0000-000000000000")
|
||||
token = await _login(api)
|
||||
resp = await api.get(
|
||||
f"/api/v1/stream/00000000-0000-0000-0000-000000000000?token={token}"
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
async def test_stream_requires_auth(api: AsyncClient) -> None:
|
||||
resp = await api.get("/api/v1/stream/00000000-0000-0000-0000-000000000000")
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
async def test_upload_requires_auth(api: AsyncClient) -> None:
|
||||
resp = await api.post(
|
||||
"/api/v1/upload",
|
||||
|
||||
Reference in New Issue
Block a user