feat: local storage logic & endpoints
This commit is contained in:
+31
-1
@@ -15,17 +15,22 @@ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.application.auth_service import AuthService
|
||||
from app.application.streaming_service import StreamingService
|
||||
from app.application.upload_service import UploadService
|
||||
from app.application.user_service import UserService
|
||||
from app.core.config import get_settings
|
||||
from app.core.security import Argon2PasswordHasher, JwtTokenService
|
||||
from app.domain.entities import User
|
||||
from app.domain.errors import AuthenticationError, PermissionDeniedError
|
||||
from app.domain.ports import PasswordHasher, TokenService
|
||||
from app.domain.ports import FileStorage, PasswordHasher, TokenService
|
||||
from app.infrastructure.db import get_sessionmaker
|
||||
from app.infrastructure.db.repositories import (
|
||||
SqlAlchemyArtistRepository,
|
||||
SqlAlchemyRefreshTokenRepository,
|
||||
SqlAlchemyTrackRepository,
|
||||
SqlAlchemyUserRepository,
|
||||
)
|
||||
from app.infrastructure.storage.provider import get_file_storage
|
||||
|
||||
|
||||
async def get_session() -> AsyncIterator[AsyncSession]:
|
||||
@@ -77,6 +82,31 @@ AuthServiceDep = Annotated[AuthService, Depends(get_auth_service)]
|
||||
UserServiceDep = Annotated[UserService, Depends(get_user_service)]
|
||||
|
||||
|
||||
# -- file storage (process-cached) ---------------------------------------------
|
||||
FileStorageDep = Annotated[FileStorage, Depends(get_file_storage)]
|
||||
|
||||
|
||||
def get_upload_service(session: SessionDep, storage: FileStorageDep) -> UploadService:
|
||||
settings = get_settings()
|
||||
return UploadService(
|
||||
tracks=SqlAlchemyTrackRepository(session),
|
||||
artists=SqlAlchemyArtistRepository(session),
|
||||
storage=storage,
|
||||
tmp_dir=settings.upload_tmp_dir,
|
||||
)
|
||||
|
||||
|
||||
def get_streaming_service(session: SessionDep, storage: FileStorageDep) -> StreamingService:
|
||||
return StreamingService(
|
||||
tracks=SqlAlchemyTrackRepository(session),
|
||||
storage=storage,
|
||||
)
|
||||
|
||||
|
||||
UploadServiceDep = Annotated[UploadService, Depends(get_upload_service)]
|
||||
StreamingServiceDep = Annotated[StreamingService, Depends(get_streaming_service)]
|
||||
|
||||
|
||||
# -- current user / authorization ----------------------------------------------
|
||||
# auto_error=False: we raise domain AuthenticationError (mapped to 401) so the
|
||||
# error envelope stays consistent with the rest of the API.
|
||||
|
||||
@@ -12,6 +12,8 @@ from app.domain.errors import (
|
||||
DomainError,
|
||||
NotFoundError,
|
||||
PermissionDeniedError,
|
||||
RangeNotSatisfiableError,
|
||||
StorageError,
|
||||
ValidationError,
|
||||
)
|
||||
|
||||
@@ -25,6 +27,7 @@ _STATUS_BY_ERROR: dict[type[DomainError], int] = {
|
||||
AuthenticationError: status.HTTP_401_UNAUTHORIZED,
|
||||
PermissionDeniedError: status.HTTP_403_FORBIDDEN,
|
||||
DependencyUnavailableError: status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
StorageError: status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
}
|
||||
|
||||
|
||||
@@ -33,6 +36,14 @@ def _error_body(code: str, message: str) -> dict[str, dict[str, str]]:
|
||||
|
||||
|
||||
def register_exception_handlers(app: FastAPI) -> None:
|
||||
@app.exception_handler(RangeNotSatisfiableError)
|
||||
async def _handle_range_error(_request: Request, exc: RangeNotSatisfiableError) -> JSONResponse:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE,
|
||||
content=_error_body(exc.code, exc.message),
|
||||
headers={"Content-Range": f"bytes */{exc.total_size}"},
|
||||
)
|
||||
|
||||
@app.exception_handler(DomainError)
|
||||
async def _handle_domain_error(_request: Request, exc: DomainError) -> JSONResponse:
|
||||
http_status = _STATUS_BY_ERROR.get(type(exc), status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
"""Schemas for upload responses."""
|
||||
|
||||
import uuid
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class UploadResponse(BaseModel):
|
||||
track_id: uuid.UUID
|
||||
title: str
|
||||
already_exists: bool
|
||||
+27
-9
@@ -1,20 +1,38 @@
|
||||
"""Audio streaming endpoints: direct stream and HLS."""
|
||||
"""Audio streaming endpoint — direct stream with Range support."""
|
||||
|
||||
import uuid
|
||||
from typing import Any
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import APIRouter, Header
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app.api.deps import StreamingServiceDep
|
||||
|
||||
router = APIRouter(prefix="/stream", tags=["streaming"])
|
||||
|
||||
|
||||
@router.get("/{track_id}")
|
||||
async def stream_track(track_id: uuid.UUID) -> Any: ...
|
||||
async def stream_track(
|
||||
track_id: uuid.UUID,
|
||||
service: StreamingServiceDep,
|
||||
range_header: Annotated[str | None, Header(alias="Range")] = None,
|
||||
) -> StreamingResponse:
|
||||
result = await service.open_stream(track_id, range_header)
|
||||
|
||||
headers = {
|
||||
"Accept-Ranges": "bytes",
|
||||
"Content-Length": str(result.content_length),
|
||||
}
|
||||
|
||||
@router.get("/{track_id}/hls/playlist.m3u8")
|
||||
async def hls_playlist(track_id: uuid.UUID) -> Any: ...
|
||||
if result.is_partial:
|
||||
headers["Content-Range"] = f"bytes {result.start}-{result.end}/{result.total_size}"
|
||||
status_code = 206
|
||||
else:
|
||||
status_code = 200
|
||||
|
||||
|
||||
@router.get("/{track_id}/hls/{segment}")
|
||||
async def hls_segment(track_id: uuid.UUID, segment: str) -> Any: ...
|
||||
return StreamingResponse(
|
||||
result.stream,
|
||||
status_code=status_code,
|
||||
headers=headers,
|
||||
media_type=result.content_type,
|
||||
)
|
||||
|
||||
+17
-4
@@ -1,11 +1,24 @@
|
||||
"""Local file upload endpoint."""
|
||||
|
||||
from typing import Any
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import APIRouter, File, UploadFile
|
||||
|
||||
from app.api.deps import CurrentUser, UploadServiceDep
|
||||
from app.api.schemas.upload import UploadResponse
|
||||
|
||||
router = APIRouter(prefix="/upload", tags=["upload"])
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def upload_file() -> Any: ...
|
||||
@router.post("", response_model=UploadResponse)
|
||||
async def upload_file(
|
||||
file: Annotated[UploadFile, File()],
|
||||
current_user: CurrentUser,
|
||||
service: UploadServiceDep,
|
||||
) -> UploadResponse:
|
||||
result = await service.handle_upload(upload=file, user=current_user)
|
||||
return UploadResponse(
|
||||
track_id=result.track_id,
|
||||
title=result.title,
|
||||
already_exists=result.already_exists,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
"""StreamingService — resolves a track and opens a byte-range stream."""
|
||||
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
from dataclasses import dataclass
|
||||
|
||||
from app.domain.errors import NotFoundError, RangeNotSatisfiableError
|
||||
from app.domain.ports import FileStorage, TrackRepository
|
||||
|
||||
_FORMAT_CONTENT_TYPE: dict[str, str] = {
|
||||
"mp3": "audio/mpeg",
|
||||
"flac": "audio/flac",
|
||||
"m4a": "audio/mp4",
|
||||
"aac": "audio/aac",
|
||||
"ogg": "audio/ogg",
|
||||
"opus": "audio/ogg",
|
||||
"wav": "audio/wav",
|
||||
"wma": "audio/x-ms-wma",
|
||||
"aiff": "audio/aiff",
|
||||
"aif": "audio/aiff",
|
||||
}
|
||||
|
||||
_RANGE_RE = re.compile(r"bytes=(\d+)-(\d*)")
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamResult:
|
||||
stream: AsyncIterator[bytes]
|
||||
total_size: int
|
||||
content_length: int
|
||||
content_type: str
|
||||
start: int
|
||||
end: int
|
||||
is_partial: bool
|
||||
|
||||
|
||||
def _parse_range(header: str | None, total_size: int) -> tuple[int, int | None, bool]:
|
||||
"""Return (start, end, is_partial). Raises RangeNotSatisfiableError on invalid range."""
|
||||
if header is None:
|
||||
return 0, None, False
|
||||
|
||||
m = _RANGE_RE.fullmatch(header.strip())
|
||||
if not m:
|
||||
return 0, None, False # malformed → treat as absent per RFC 7233
|
||||
|
||||
start = int(m.group(1))
|
||||
end: int | None = int(m.group(2)) if m.group(2) else None
|
||||
|
||||
if start >= total_size:
|
||||
raise RangeNotSatisfiableError(total_size)
|
||||
|
||||
if end is not None:
|
||||
if end >= total_size:
|
||||
end = total_size - 1
|
||||
if end < start:
|
||||
raise RangeNotSatisfiableError(total_size)
|
||||
|
||||
return start, end, True
|
||||
|
||||
|
||||
class StreamingService:
|
||||
def __init__(self, tracks: TrackRepository, storage: FileStorage) -> None:
|
||||
self._tracks = tracks
|
||||
self._storage = storage
|
||||
|
||||
async def open_stream(
|
||||
self,
|
||||
track_id: uuid.UUID,
|
||||
range_header: str | None,
|
||||
) -> StreamResult:
|
||||
track = await self._tracks.get_by_id(track_id)
|
||||
if track is None:
|
||||
raise NotFoundError("Track not found.")
|
||||
|
||||
stat = await self._storage.stat(track.file_path)
|
||||
total_size = stat.size
|
||||
content_type = stat.content_type or _FORMAT_CONTENT_TYPE.get(
|
||||
track.file_format.lower(), "application/octet-stream"
|
||||
)
|
||||
|
||||
start, end, is_partial = _parse_range(range_header, total_size)
|
||||
|
||||
stream, _ = await self._storage.open_range(track.file_path, start, end)
|
||||
|
||||
actual_end = end if end is not None else total_size - 1
|
||||
content_length = actual_end - start + 1
|
||||
|
||||
return StreamResult(
|
||||
stream=stream,
|
||||
total_size=total_size,
|
||||
content_length=content_length,
|
||||
content_type=content_type,
|
||||
start=start,
|
||||
end=actual_end,
|
||||
is_partial=is_partial,
|
||||
)
|
||||
@@ -0,0 +1,116 @@
|
||||
"""UploadService — handles user file uploads."""
|
||||
|
||||
import contextlib
|
||||
import hashlib
|
||||
import os
|
||||
import tempfile
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Protocol
|
||||
|
||||
import anyio
|
||||
|
||||
from app.domain.entities.user import User
|
||||
from app.domain.ports import ArtistRepository, FileStorage, TrackRepository
|
||||
|
||||
|
||||
class UploadFileProtocol(Protocol):
|
||||
filename: str | None
|
||||
|
||||
async def read(self, size: int = -1) -> bytes: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class UploadResult:
|
||||
track_id: uuid.UUID
|
||||
title: str
|
||||
already_exists: bool
|
||||
|
||||
|
||||
async def _stream_to_temp(upload: UploadFileProtocol, dest: Path) -> tuple[str, int]:
|
||||
h = hashlib.sha256()
|
||||
size = 0
|
||||
async with await anyio.open_file(dest, "wb") as out:
|
||||
while True:
|
||||
chunk = await upload.read(65536)
|
||||
if not chunk:
|
||||
break
|
||||
h.update(chunk)
|
||||
await out.write(chunk)
|
||||
size += len(chunk)
|
||||
return h.hexdigest(), size
|
||||
|
||||
|
||||
class UploadService:
|
||||
def __init__(
|
||||
self,
|
||||
tracks: TrackRepository,
|
||||
artists: ArtistRepository,
|
||||
storage: FileStorage,
|
||||
tmp_dir: Path | None = None,
|
||||
) -> None:
|
||||
self._tracks = tracks
|
||||
self._artists = artists
|
||||
self._storage = storage
|
||||
self._tmp_dir = tmp_dir
|
||||
|
||||
async def handle_upload(
|
||||
self,
|
||||
*,
|
||||
upload: UploadFileProtocol,
|
||||
user: User,
|
||||
) -> UploadResult:
|
||||
filename = upload.filename or "unknown"
|
||||
ext = Path(filename).suffix.lower().lstrip(".") or "bin"
|
||||
title = Path(filename).stem or "Unknown"
|
||||
|
||||
fd, tmp_str = tempfile.mkstemp(
|
||||
suffix=f".{ext}",
|
||||
dir=str(self._tmp_dir) if self._tmp_dir else None,
|
||||
)
|
||||
tmp_path = Path(tmp_str)
|
||||
try:
|
||||
os.close(fd)
|
||||
sha256_hex, file_size = await _stream_to_temp(upload, tmp_path)
|
||||
|
||||
existing = await self._tracks.get_by_source("upload", sha256_hex)
|
||||
if existing is not None:
|
||||
return UploadResult(
|
||||
track_id=existing.id,
|
||||
title=existing.title,
|
||||
already_exists=True,
|
||||
)
|
||||
|
||||
track_id = uuid.uuid4()
|
||||
key = f"tracks/{str(track_id)[:2]}/{track_id}.{ext}"
|
||||
|
||||
await self._storage.save_file(key, tmp_path)
|
||||
try:
|
||||
artist = await self._artists.get_or_create("Unknown Artist")
|
||||
track = await self._tracks.add(
|
||||
id=track_id,
|
||||
title=title,
|
||||
artist_id=artist.id,
|
||||
file_path=key,
|
||||
file_format=ext,
|
||||
file_size=file_size,
|
||||
source="upload",
|
||||
source_id=sha256_hex,
|
||||
metadata_status="pending",
|
||||
added_by=user.id,
|
||||
)
|
||||
except Exception:
|
||||
with contextlib.suppress(Exception):
|
||||
await self._storage.delete(key)
|
||||
raise
|
||||
|
||||
# TODO(1D): enqueue metadata enrichment task
|
||||
|
||||
return UploadResult(
|
||||
track_id=track.id,
|
||||
title=track.title,
|
||||
already_exists=False,
|
||||
)
|
||||
finally:
|
||||
await anyio.Path(tmp_path).unlink(missing_ok=True)
|
||||
@@ -49,6 +49,15 @@ class Settings(BaseSettings):
|
||||
media_path: Path = Path("/data/media")
|
||||
transcode_cache_path: Path = Path("/data/transcode-cache")
|
||||
max_parallel_downloads: int = 2
|
||||
storage_backend: Literal["local", "s3"] = "local"
|
||||
upload_tmp_dir: Path | None = None
|
||||
|
||||
# -- S3 storage (deferred; set storage_backend="s3" to use) ----------
|
||||
s3_endpoint_url: str | None = None
|
||||
s3_bucket: str | None = None
|
||||
s3_region: str | None = None
|
||||
s3_access_key: SecretStr | None = None
|
||||
s3_secret_key: SecretStr | None = None
|
||||
|
||||
# -- external services (all optional; graceful degradation) ----------
|
||||
ml_service_url: str | None = None
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
"""File hashing utilities."""
|
||||
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def sha256_of_file(path: Path) -> str:
|
||||
h = hashlib.sha256()
|
||||
with open(path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(65536), b""):
|
||||
h.update(chunk)
|
||||
return h.hexdigest()
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Domain entities and value objects — pure, framework-free."""
|
||||
|
||||
from app.domain.entities.storage import ObjectStat
|
||||
from app.domain.entities.track import Artist, Track
|
||||
from app.domain.entities.user import Credentials, User
|
||||
|
||||
__all__ = ["Credentials", "User"]
|
||||
__all__ = ["Artist", "Credentials", "ObjectStat", "Track", "User"]
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
"""Value objects for file storage."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ObjectStat:
|
||||
size: int
|
||||
content_type: str | None
|
||||
@@ -0,0 +1,29 @@
|
||||
"""Track and Artist domain entities."""
|
||||
|
||||
import datetime as dt
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class Artist:
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
created_at: dt.datetime
|
||||
updated_at: dt.datetime
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class Track:
|
||||
id: uuid.UUID
|
||||
title: str
|
||||
artist_id: uuid.UUID
|
||||
file_path: str
|
||||
file_format: str
|
||||
file_size: int
|
||||
source: str
|
||||
source_id: str
|
||||
duration_seconds: int | None
|
||||
metadata_status: str
|
||||
created_at: dt.datetime
|
||||
updated_at: dt.datetime
|
||||
@@ -61,3 +61,19 @@ class DependencyUnavailableError(DomainError):
|
||||
"""
|
||||
|
||||
code = "dependency_unavailable"
|
||||
|
||||
|
||||
class StorageError(DomainError):
|
||||
"""File storage operation failed."""
|
||||
|
||||
code = "storage_error"
|
||||
|
||||
|
||||
class RangeNotSatisfiableError(DomainError):
|
||||
"""Requested byte range cannot be satisfied."""
|
||||
|
||||
code = "range_not_satisfiable"
|
||||
|
||||
def __init__(self, total_size: int) -> None:
|
||||
super().__init__("Requested range is not satisfiable.")
|
||||
self.total_size = total_size
|
||||
|
||||
+40
-1
@@ -7,9 +7,13 @@ are bound to these ports at the composition root (``app.api.deps``).
|
||||
|
||||
import datetime as dt
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import AbstractAsyncContextManager
|
||||
from pathlib import Path
|
||||
from typing import Protocol
|
||||
|
||||
from app.domain.entities import Credentials, User
|
||||
from app.domain.entities import Credentials, ObjectStat, User
|
||||
from app.domain.entities.track import Artist, Track
|
||||
from app.domain.tokens import IssuedToken, TokenClaims, TokenType
|
||||
|
||||
|
||||
@@ -56,3 +60,38 @@ class TokenService(Protocol):
|
||||
"""Verify signature + expiry and return claims. Raises
|
||||
:class:`~app.domain.errors.AuthenticationError` on any failure."""
|
||||
...
|
||||
|
||||
|
||||
class FileStorage(Protocol):
|
||||
async def save_file(self, key: str, src_path: Path) -> int: ...
|
||||
async def open_range(
|
||||
self, key: str, start: int, end: int | None
|
||||
) -> tuple[AsyncIterator[bytes], int]: ...
|
||||
async def stat(self, key: str) -> ObjectStat: ...
|
||||
async def exists(self, key: str) -> bool: ...
|
||||
async def delete(self, key: str) -> None: ...
|
||||
def as_local_path(self, key: str) -> AbstractAsyncContextManager[Path]: ...
|
||||
|
||||
|
||||
class ArtistRepository(Protocol):
|
||||
async def get_or_create(self, name: str) -> Artist: ...
|
||||
|
||||
|
||||
class TrackRepository(Protocol):
|
||||
async def get_by_id(self, track_id: uuid.UUID) -> Track | None: ...
|
||||
async def get_by_source(self, source: str, source_id: str) -> Track | None: ...
|
||||
async def add(
|
||||
self,
|
||||
*,
|
||||
id: uuid.UUID,
|
||||
title: str,
|
||||
artist_id: uuid.UUID,
|
||||
file_path: str,
|
||||
file_format: str,
|
||||
file_size: int,
|
||||
source: str,
|
||||
source_id: str,
|
||||
metadata_status: str,
|
||||
added_by: uuid.UUID | None,
|
||||
) -> Track: ...
|
||||
async def delete(self, track_id: uuid.UUID) -> None: ...
|
||||
|
||||
@@ -1,8 +1,15 @@
|
||||
"""SQLAlchemy repository adapters implementing the domain ports."""
|
||||
|
||||
from app.infrastructure.db.repositories.artist_repository import SqlAlchemyArtistRepository
|
||||
from app.infrastructure.db.repositories.refresh_token_repository import (
|
||||
SqlAlchemyRefreshTokenRepository,
|
||||
)
|
||||
from app.infrastructure.db.repositories.track_repository import SqlAlchemyTrackRepository
|
||||
from app.infrastructure.db.repositories.user_repository import SqlAlchemyUserRepository
|
||||
|
||||
__all__ = ["SqlAlchemyRefreshTokenRepository", "SqlAlchemyUserRepository"]
|
||||
__all__ = [
|
||||
"SqlAlchemyArtistRepository",
|
||||
"SqlAlchemyRefreshTokenRepository",
|
||||
"SqlAlchemyTrackRepository",
|
||||
"SqlAlchemyUserRepository",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
"""Artist repository — adapter over ``AsyncSession``."""
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.domain.entities.track import Artist
|
||||
from app.infrastructure.db.models.artist import ArtistModel
|
||||
|
||||
|
||||
def _to_entity(row: ArtistModel) -> Artist:
|
||||
return Artist(
|
||||
id=row.id,
|
||||
name=row.name,
|
||||
created_at=row.created_at,
|
||||
updated_at=row.updated_at,
|
||||
)
|
||||
|
||||
|
||||
class SqlAlchemyArtistRepository:
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
|
||||
async def get_or_create(self, name: str) -> Artist:
|
||||
row = (
|
||||
await self._session.execute(select(ArtistModel).where(ArtistModel.name == name))
|
||||
).scalar_one_or_none()
|
||||
if row is None:
|
||||
row = ArtistModel(name=name)
|
||||
self._session.add(row)
|
||||
await self._session.flush()
|
||||
await self._session.refresh(row)
|
||||
return _to_entity(row)
|
||||
@@ -0,0 +1,83 @@
|
||||
"""Track repository — adapter over ``AsyncSession``."""
|
||||
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.domain.entities.track import Track
|
||||
from app.infrastructure.db.models.track import TrackModel
|
||||
|
||||
|
||||
def _to_entity(row: TrackModel) -> Track:
|
||||
return Track(
|
||||
id=row.id,
|
||||
title=row.title,
|
||||
artist_id=row.artist_id,
|
||||
file_path=row.file_path,
|
||||
file_format=row.file_format,
|
||||
file_size=row.file_size,
|
||||
source=row.source,
|
||||
source_id=row.source_id,
|
||||
duration_seconds=row.duration_seconds,
|
||||
metadata_status=row.metadata_status,
|
||||
created_at=row.created_at,
|
||||
updated_at=row.updated_at,
|
||||
)
|
||||
|
||||
|
||||
class SqlAlchemyTrackRepository:
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
|
||||
async def get_by_id(self, track_id: uuid.UUID) -> Track | None:
|
||||
row = await self._session.get(TrackModel, track_id)
|
||||
return _to_entity(row) if row is not None else None
|
||||
|
||||
async def get_by_source(self, source: str, source_id: str) -> Track | None:
|
||||
row = (
|
||||
await self._session.execute(
|
||||
select(TrackModel).where(
|
||||
TrackModel.source == source,
|
||||
TrackModel.source_id == source_id,
|
||||
)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
return _to_entity(row) if row is not None else None
|
||||
|
||||
async def add(
|
||||
self,
|
||||
*,
|
||||
id: uuid.UUID,
|
||||
title: str,
|
||||
artist_id: uuid.UUID,
|
||||
file_path: str,
|
||||
file_format: str,
|
||||
file_size: int,
|
||||
source: str,
|
||||
source_id: str,
|
||||
metadata_status: str,
|
||||
added_by: uuid.UUID | None,
|
||||
) -> Track:
|
||||
row = TrackModel(
|
||||
id=id,
|
||||
title=title,
|
||||
artist_id=artist_id,
|
||||
file_path=file_path,
|
||||
file_format=file_format,
|
||||
file_size=file_size,
|
||||
source=source,
|
||||
source_id=source_id,
|
||||
metadata_status=metadata_status,
|
||||
added_by=added_by,
|
||||
)
|
||||
self._session.add(row)
|
||||
await self._session.flush()
|
||||
await self._session.refresh(row)
|
||||
return _to_entity(row)
|
||||
|
||||
async def delete(self, track_id: uuid.UUID) -> None:
|
||||
row = await self._session.get(TrackModel, track_id)
|
||||
if row is not None:
|
||||
await self._session.delete(row)
|
||||
await self._session.flush()
|
||||
@@ -0,0 +1 @@
|
||||
"""File storage adapters."""
|
||||
@@ -0,0 +1,86 @@
|
||||
"""LocalFileStorage — stores files on the local filesystem."""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import anyio
|
||||
|
||||
from app.domain.entities.storage import ObjectStat
|
||||
from app.domain.errors import StorageError
|
||||
|
||||
_EXT_CONTENT_TYPE: dict[str, str] = {
|
||||
"mp3": "audio/mpeg",
|
||||
"flac": "audio/flac",
|
||||
"m4a": "audio/mp4",
|
||||
"aac": "audio/aac",
|
||||
"ogg": "audio/ogg",
|
||||
"opus": "audio/ogg",
|
||||
"wav": "audio/wav",
|
||||
"wma": "audio/x-ms-wma",
|
||||
"aiff": "audio/aiff",
|
||||
"aif": "audio/aiff",
|
||||
}
|
||||
|
||||
|
||||
class LocalFileStorage:
|
||||
def __init__(self, media_path: Path) -> None:
|
||||
self._media_path = media_path
|
||||
|
||||
async def save_file(self, key: str, src_path: Path) -> int:
|
||||
dest = self._media_path / key
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
part = dest.with_suffix(dest.suffix + ".part")
|
||||
shutil.copyfile(str(src_path), str(part))
|
||||
os.replace(str(part), str(dest))
|
||||
return dest.stat().st_size
|
||||
|
||||
async def open_range(
|
||||
self, key: str, start: int, end: int | None
|
||||
) -> tuple[AsyncIterator[bytes], int]:
|
||||
path = self._media_path / key
|
||||
if not path.exists():
|
||||
raise StorageError(f"Object not found: {key}")
|
||||
total_size = path.stat().st_size
|
||||
|
||||
_start = start
|
||||
_end = end
|
||||
_total_size = total_size
|
||||
_path = path
|
||||
|
||||
async def _iter() -> AsyncGenerator[bytes]:
|
||||
async with await anyio.open_file(_path, "rb") as f:
|
||||
await f.seek(_start)
|
||||
remaining = (_end - _start + 1) if _end is not None else (_total_size - _start)
|
||||
while remaining > 0:
|
||||
chunk: bytes = await f.read(min(65536, remaining))
|
||||
if not chunk:
|
||||
break
|
||||
yield chunk
|
||||
remaining -= len(chunk)
|
||||
|
||||
aiter: AsyncIterator[bytes] = _iter()
|
||||
return aiter, total_size
|
||||
|
||||
async def stat(self, key: str) -> ObjectStat:
|
||||
path = self._media_path / key
|
||||
if not path.exists():
|
||||
raise StorageError(f"Object not found: {key}")
|
||||
st = path.stat()
|
||||
ext = path.suffix.lower().lstrip(".")
|
||||
return ObjectStat(size=st.st_size, content_type=_EXT_CONTENT_TYPE.get(ext))
|
||||
|
||||
async def exists(self, key: str) -> bool:
|
||||
return (self._media_path / key).exists()
|
||||
|
||||
async def delete(self, key: str) -> None:
|
||||
(self._media_path / key).unlink(missing_ok=True)
|
||||
|
||||
def as_local_path(self, key: str) -> AbstractAsyncContextManager[Path]:
|
||||
return self._as_local_path_cm(key)
|
||||
|
||||
@asynccontextmanager
|
||||
async def _as_local_path_cm(self, key: str) -> AsyncGenerator[Path]:
|
||||
yield self._media_path / key
|
||||
@@ -0,0 +1,16 @@
|
||||
"""File storage provider — singleton factory."""
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.infrastructure.storage.local import LocalFileStorage
|
||||
|
||||
_storage: LocalFileStorage | None = None
|
||||
|
||||
|
||||
def get_file_storage() -> LocalFileStorage:
|
||||
global _storage
|
||||
if _storage is None:
|
||||
settings = get_settings()
|
||||
if settings.storage_backend == "s3":
|
||||
raise NotImplementedError("S3 storage not yet implemented.")
|
||||
_storage = LocalFileStorage(settings.media_path)
|
||||
return _storage
|
||||
Reference in New Issue
Block a user