feat: local storage logic & endpoints
This commit is contained in:
@@ -10,7 +10,7 @@ from collections.abc import AsyncIterator
|
||||
|
||||
import pytest
|
||||
from app.core.security import Argon2PasswordHasher
|
||||
from app.infrastructure.db import Base, get_engine, session_scope
|
||||
from app.infrastructure.db import Base, dispose_engine, get_engine, session_scope
|
||||
from app.infrastructure.db.repositories import (
|
||||
SqlAlchemyRefreshTokenRepository,
|
||||
SqlAlchemyUserRepository,
|
||||
@@ -73,6 +73,7 @@ async def api() -> AsyncIterator[AsyncClient]:
|
||||
|
||||
async with get_engine().begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
await dispose_engine()
|
||||
|
||||
|
||||
async def _login(api: AsyncClient, username: str, password: str) -> tuple[str, str]:
|
||||
|
||||
@@ -0,0 +1,104 @@
|
||||
"""Unit tests for LocalFileStorage."""
|
||||
|
||||
import pytest
|
||||
from app.infrastructure.storage.local import LocalFileStorage
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
async def test_save_and_stat(tmp_path):
|
||||
storage = LocalFileStorage(tmp_path)
|
||||
src = tmp_path / "src.mp3"
|
||||
src.write_bytes(b"test audio data")
|
||||
|
||||
size = await storage.save_file("tracks/te/test.mp3", src)
|
||||
assert size == 15
|
||||
|
||||
stat = await storage.stat("tracks/te/test.mp3")
|
||||
assert stat.size == 15
|
||||
assert stat.content_type == "audio/mpeg"
|
||||
|
||||
|
||||
async def test_save_creates_parent_dirs(tmp_path):
|
||||
storage = LocalFileStorage(tmp_path)
|
||||
src = tmp_path / "src.flac"
|
||||
src.write_bytes(b"x")
|
||||
|
||||
await storage.save_file("tracks/ab/cdef.flac", src)
|
||||
assert (tmp_path / "tracks" / "ab" / "cdef.flac").exists()
|
||||
|
||||
|
||||
async def test_open_range_full(tmp_path):
|
||||
storage = LocalFileStorage(tmp_path)
|
||||
data = b"hello world" * 100
|
||||
src = tmp_path / "src.flac"
|
||||
src.write_bytes(data)
|
||||
await storage.save_file("tracks/he/hello.flac", src)
|
||||
|
||||
stream, total = await storage.open_range("tracks/he/hello.flac", 0, None)
|
||||
result = b"".join([chunk async for chunk in stream])
|
||||
assert result == data
|
||||
assert total == len(data)
|
||||
|
||||
|
||||
async def test_open_range_partial(tmp_path):
|
||||
storage = LocalFileStorage(tmp_path)
|
||||
data = b"0123456789"
|
||||
src = tmp_path / "src.mp3"
|
||||
src.write_bytes(data)
|
||||
await storage.save_file("tracks/sr/src.mp3", src)
|
||||
|
||||
stream, total = await storage.open_range("tracks/sr/src.mp3", 3, 7)
|
||||
result = b"".join([chunk async for chunk in stream])
|
||||
assert result == b"34567"
|
||||
assert total == 10
|
||||
|
||||
|
||||
async def test_open_range_from_offset_to_end(tmp_path):
|
||||
storage = LocalFileStorage(tmp_path)
|
||||
data = b"abcdefghij"
|
||||
src = tmp_path / "src.wav"
|
||||
src.write_bytes(data)
|
||||
await storage.save_file("tracks/sr/src.wav", src)
|
||||
|
||||
stream, total = await storage.open_range("tracks/sr/src.wav", 5, None)
|
||||
result = b"".join([chunk async for chunk in stream])
|
||||
assert result == b"fghij"
|
||||
assert total == 10
|
||||
|
||||
|
||||
async def test_exists_and_delete(tmp_path):
|
||||
storage = LocalFileStorage(tmp_path)
|
||||
src = tmp_path / "src.ogg"
|
||||
src.write_bytes(b"ogg data")
|
||||
await storage.save_file("tracks/sr/src.ogg", src)
|
||||
|
||||
assert await storage.exists("tracks/sr/src.ogg") is True
|
||||
await storage.delete("tracks/sr/src.ogg")
|
||||
assert await storage.exists("tracks/sr/src.ogg") is False
|
||||
|
||||
|
||||
async def test_delete_missing_is_noop(tmp_path):
|
||||
storage = LocalFileStorage(tmp_path)
|
||||
await storage.delete("tracks/no/nope.mp3")
|
||||
|
||||
|
||||
async def test_as_local_path(tmp_path):
|
||||
storage = LocalFileStorage(tmp_path)
|
||||
src = tmp_path / "src.mp3"
|
||||
src.write_bytes(b"local bytes")
|
||||
await storage.save_file("tracks/lo/local.mp3", src)
|
||||
|
||||
async with storage.as_local_path("tracks/lo/local.mp3") as path:
|
||||
assert path.read_bytes() == b"local bytes"
|
||||
|
||||
|
||||
async def test_stat_unknown_extension(tmp_path):
|
||||
storage = LocalFileStorage(tmp_path)
|
||||
src = tmp_path / "src.xyz"
|
||||
src.write_bytes(b"mystery")
|
||||
await storage.save_file("tracks/my/mystery.xyz", src)
|
||||
|
||||
stat = await storage.stat("tracks/my/mystery.xyz")
|
||||
assert stat.size == 7
|
||||
assert stat.content_type is None
|
||||
@@ -0,0 +1,185 @@
|
||||
"""Integration tests for upload and streaming endpoints.
|
||||
|
||||
Requires a reachable Postgres; skips otherwise.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from collections.abc import AsyncIterator
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from app.core.config import get_settings
|
||||
from app.infrastructure.db import Base, dispose_engine, get_engine, session_scope
|
||||
from app.infrastructure.db.repositories import (
|
||||
SqlAlchemyRefreshTokenRepository,
|
||||
SqlAlchemyUserRepository,
|
||||
)
|
||||
from asgi_lifespan import LifespanManager
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
_db_reachable_cache: bool | None = None
|
||||
|
||||
|
||||
async def _db_reachable() -> bool:
|
||||
global _db_reachable_cache
|
||||
if _db_reachable_cache is not None:
|
||||
return _db_reachable_cache
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
try:
|
||||
async with asyncio.timeout(3):
|
||||
async with get_engine().connect() as conn:
|
||||
await conn.execute(text("SELECT 1"))
|
||||
_db_reachable_cache = True
|
||||
except Exception:
|
||||
_db_reachable_cache = False
|
||||
return _db_reachable_cache
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def api(tmp_path: Path) -> AsyncIterator[AsyncClient]:
|
||||
if not await _db_reachable():
|
||||
pytest.skip("Postgres not reachable — integration test skipped.")
|
||||
|
||||
os.environ["MEDIA_PATH"] = str(tmp_path)
|
||||
get_settings.cache_clear()
|
||||
|
||||
# Also reset the file storage singleton so it picks up the new media_path
|
||||
import app.infrastructure.storage.provider as _storage_provider
|
||||
|
||||
_storage_provider._storage = None
|
||||
|
||||
try:
|
||||
async with get_engine().begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
from app.application.user_service import UserService
|
||||
from app.core.security import Argon2PasswordHasher
|
||||
|
||||
async with session_scope() as session:
|
||||
await UserService(
|
||||
users=SqlAlchemyUserRepository(session),
|
||||
refresh_tokens=SqlAlchemyRefreshTokenRepository(session),
|
||||
hasher=Argon2PasswordHasher(),
|
||||
).create_user(username="testuser", password="testpass1", is_superuser=False)
|
||||
|
||||
from app.main import create_app
|
||||
|
||||
app = create_app()
|
||||
async with LifespanManager(app):
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
yield client
|
||||
|
||||
async with get_engine().begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
await dispose_engine()
|
||||
finally:
|
||||
_storage_provider._storage = None
|
||||
os.environ.pop("MEDIA_PATH", None)
|
||||
get_settings.cache_clear()
|
||||
|
||||
|
||||
async def _login(api: AsyncClient) -> str:
|
||||
resp = await api.post(
|
||||
"/api/v1/auth/login", json={"username": "testuser", "password": "testpass1"}
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
return str(resp.json()["access_token"])
|
||||
|
||||
|
||||
async def test_upload_creates_track(api: AsyncClient) -> None:
|
||||
token = await _login(api)
|
||||
audio = b"fake mp3 bytes" * 100
|
||||
|
||||
resp = await api.post(
|
||||
"/api/v1/upload",
|
||||
files={"file": ("song.mp3", audio, "audio/mpeg")},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert resp.status_code == 200, resp.text
|
||||
body = resp.json()
|
||||
assert body["already_exists"] is False
|
||||
assert body["title"] == "song"
|
||||
assert "track_id" in body
|
||||
|
||||
|
||||
async def test_upload_dedup(api: AsyncClient) -> None:
|
||||
token = await _login(api)
|
||||
audio = b"same content" * 50
|
||||
|
||||
first = await api.post(
|
||||
"/api/v1/upload",
|
||||
files={"file": ("a.mp3", audio, "audio/mpeg")},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert first.status_code == 200
|
||||
assert first.json()["already_exists"] is False
|
||||
|
||||
second = await api.post(
|
||||
"/api/v1/upload",
|
||||
files={"file": ("b.mp3", audio, "audio/mpeg")},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert second.status_code == 200
|
||||
assert second.json()["already_exists"] is True
|
||||
assert second.json()["track_id"] == first.json()["track_id"]
|
||||
|
||||
|
||||
async def test_stream_full(api: AsyncClient) -> None:
|
||||
token = await _login(api)
|
||||
audio = b"audio data for streaming" * 10
|
||||
|
||||
up = await api.post(
|
||||
"/api/v1/upload",
|
||||
files={"file": ("track.mp3", audio, "audio/mpeg")},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert up.status_code == 200
|
||||
track_id = up.json()["track_id"]
|
||||
|
||||
resp = await api.get(f"/api/v1/stream/{track_id}")
|
||||
assert resp.status_code == 200
|
||||
assert resp.content == audio
|
||||
assert resp.headers["content-type"].startswith("audio/mpeg")
|
||||
assert "accept-ranges" in resp.headers
|
||||
|
||||
|
||||
async def test_stream_range(api: AsyncClient) -> None:
|
||||
token = await _login(api)
|
||||
audio = b"0123456789" * 10
|
||||
|
||||
up = await api.post(
|
||||
"/api/v1/upload",
|
||||
files={"file": ("range.mp3", audio, "audio/mpeg")},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert up.status_code == 200
|
||||
track_id = up.json()["track_id"]
|
||||
|
||||
resp = await api.get(
|
||||
f"/api/v1/stream/{track_id}",
|
||||
headers={"Range": "bytes=0-9"},
|
||||
)
|
||||
assert resp.status_code == 206
|
||||
assert resp.content == b"0123456789"
|
||||
assert resp.headers["content-range"] == f"bytes 0-9/{len(audio)}"
|
||||
assert resp.headers["content-length"] == "10"
|
||||
|
||||
|
||||
async def test_stream_not_found(api: AsyncClient) -> None:
|
||||
resp = await api.get("/api/v1/stream/00000000-0000-0000-0000-000000000000")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
async def test_upload_requires_auth(api: AsyncClient) -> None:
|
||||
resp = await api.post(
|
||||
"/api/v1/upload",
|
||||
files={"file": ("x.mp3", b"data", "audio/mpeg")},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
Reference in New Issue
Block a user