feat: auth & admin
This commit is contained in:
+102
@@ -0,0 +1,102 @@
|
||||
"""In-memory port implementations for fast, DB-free unit tests."""
|
||||
|
||||
import datetime as dt
|
||||
import uuid
|
||||
from dataclasses import dataclass, replace
|
||||
|
||||
from app.domain.entities import Credentials, User
|
||||
|
||||
|
||||
@dataclass
|
||||
class _Stored:
|
||||
user: User
|
||||
password_hash: str
|
||||
|
||||
|
||||
class InMemoryUserRepository:
|
||||
def __init__(self) -> None:
|
||||
self._by_id: dict[uuid.UUID, _Stored] = {}
|
||||
|
||||
async def get_by_id(self, user_id: uuid.UUID) -> User | None:
|
||||
stored = self._by_id.get(user_id)
|
||||
return stored.user if stored else None
|
||||
|
||||
async def get_credentials_by_username(self, username: str) -> Credentials | None:
|
||||
for stored in self._by_id.values():
|
||||
if stored.user.username == username:
|
||||
return Credentials(user=stored.user, password_hash=stored.password_hash)
|
||||
return None
|
||||
|
||||
async def add(self, *, username: str, password_hash: str, is_superuser: bool) -> User:
|
||||
now = dt.datetime.now(dt.UTC)
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
username=username,
|
||||
is_superuser=is_superuser,
|
||||
is_active=True,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
self._by_id[user.id] = _Stored(user=user, password_hash=password_hash)
|
||||
return user
|
||||
|
||||
async def list(self, *, limit: int, offset: int) -> list[User]:
|
||||
users = [s.user for s in self._by_id.values()]
|
||||
users.sort(key=lambda u: u.created_at)
|
||||
return users[offset : offset + limit]
|
||||
|
||||
async def set_password_hash(self, user_id: uuid.UUID, password_hash: str) -> None:
|
||||
self._by_id[user_id].password_hash = password_hash
|
||||
|
||||
async def set_superuser(self, user_id: uuid.UUID, is_superuser: bool) -> User:
|
||||
stored = self._by_id[user_id]
|
||||
stored.user = replace(stored.user, is_superuser=is_superuser)
|
||||
return stored.user
|
||||
|
||||
async def set_active(self, user_id: uuid.UUID, is_active: bool) -> User:
|
||||
stored = self._by_id[user_id]
|
||||
stored.user = replace(stored.user, is_active=is_active)
|
||||
return stored.user
|
||||
|
||||
async def count(self) -> int:
|
||||
return len(self._by_id)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _Token:
|
||||
user_id: uuid.UUID
|
||||
token_hash: str
|
||||
expires_at: dt.datetime
|
||||
revoked_at: dt.datetime | None = None
|
||||
|
||||
|
||||
class InMemoryRefreshTokenRepository:
|
||||
def __init__(self) -> None:
|
||||
self._by_jti: dict[uuid.UUID, _Token] = {}
|
||||
|
||||
async def add(
|
||||
self,
|
||||
*,
|
||||
jti: uuid.UUID,
|
||||
user_id: uuid.UUID,
|
||||
token_hash: str,
|
||||
expires_at: dt.datetime,
|
||||
) -> None:
|
||||
self._by_jti[jti] = _Token(user_id=user_id, token_hash=token_hash, expires_at=expires_at)
|
||||
|
||||
async def is_valid(self, jti: uuid.UUID) -> bool:
|
||||
token = self._by_jti.get(jti)
|
||||
if token is None or token.revoked_at is not None:
|
||||
return False
|
||||
return token.expires_at > dt.datetime.now(dt.UTC)
|
||||
|
||||
async def revoke(self, jti: uuid.UUID) -> None:
|
||||
token = self._by_jti.get(jti)
|
||||
if token and token.revoked_at is None:
|
||||
token.revoked_at = dt.datetime.now(dt.UTC)
|
||||
|
||||
async def revoke_all_for_user(self, user_id: uuid.UUID) -> None:
|
||||
now = dt.datetime.now(dt.UTC)
|
||||
for token in self._by_jti.values():
|
||||
if token.user_id == user_id and token.revoked_at is None:
|
||||
token.revoked_at = now
|
||||
@@ -0,0 +1,169 @@
|
||||
"""Integration tests for the auth + admin HTTP surface.
|
||||
|
||||
These require a reachable Postgres (the schema is created via metadata). When
|
||||
no DB is available they *skip* — preserving the project rule that the test
|
||||
suite never hard-requires a running database.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
import pytest
|
||||
from app.core.security import Argon2PasswordHasher
|
||||
from app.infrastructure.db import Base, get_engine, session_scope
|
||||
from app.infrastructure.db.repositories import (
|
||||
SqlAlchemyRefreshTokenRepository,
|
||||
SqlAlchemyUserRepository,
|
||||
)
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
_db_reachable_cache: bool | None = None
|
||||
|
||||
|
||||
async def _db_reachable() -> bool:
|
||||
# Probe once per session (cached): bounded so the suite never hangs when
|
||||
# nothing (or a half-open socket) is on the DB port — mirrors the
|
||||
# readiness-probe rule (never hang).
|
||||
global _db_reachable_cache
|
||||
if _db_reachable_cache is not None:
|
||||
return _db_reachable_cache
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
try:
|
||||
async with asyncio.timeout(3):
|
||||
async with get_engine().connect() as conn:
|
||||
await conn.execute(text("SELECT 1"))
|
||||
_db_reachable_cache = True
|
||||
except Exception:
|
||||
_db_reachable_cache = False
|
||||
return _db_reachable_cache
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def api() -> AsyncIterator[AsyncClient]:
|
||||
if not await _db_reachable():
|
||||
pytest.skip("Postgres not reachable — integration test skipped.")
|
||||
|
||||
# Fresh schema for the test run.
|
||||
async with get_engine().begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
# Seed an admin directly through the service layer.
|
||||
from app.application.user_service import UserService
|
||||
|
||||
async with session_scope() as session:
|
||||
await UserService(
|
||||
users=SqlAlchemyUserRepository(session),
|
||||
refresh_tokens=SqlAlchemyRefreshTokenRepository(session),
|
||||
hasher=Argon2PasswordHasher(),
|
||||
).create_user(username="admin", password="adminpass1", is_superuser=True)
|
||||
|
||||
from app.main import create_app
|
||||
from asgi_lifespan import LifespanManager
|
||||
|
||||
app = create_app()
|
||||
async with LifespanManager(app):
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
yield client
|
||||
|
||||
async with get_engine().begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
|
||||
|
||||
async def _login(api: AsyncClient, username: str, password: str) -> tuple[str, str]:
|
||||
resp = await api.post("/api/v1/auth/login", json={"username": username, "password": password})
|
||||
assert resp.status_code == 200, resp.text
|
||||
body = resp.json()
|
||||
return body["access_token"], body["refresh_token"]
|
||||
|
||||
|
||||
async def test_login_and_me(api: AsyncClient) -> None:
|
||||
access, _ = await _login(api, "admin", "adminpass1")
|
||||
resp = await api.get("/api/v1/auth/me", headers={"Authorization": f"Bearer {access}"})
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["username"] == "admin"
|
||||
assert resp.json()["is_superuser"] is True
|
||||
|
||||
|
||||
async def test_login_bad_credentials(api: AsyncClient) -> None:
|
||||
resp = await api.post("/api/v1/auth/login", json={"username": "admin", "password": "wrong"})
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
async def test_me_requires_token(api: AsyncClient) -> None:
|
||||
resp = await api.get("/api/v1/auth/me")
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
async def test_refresh_rotation(api: AsyncClient) -> None:
|
||||
_, refresh = await _login(api, "admin", "adminpass1")
|
||||
resp = await api.post("/api/v1/auth/refresh", json={"refresh_token": refresh})
|
||||
assert resp.status_code == 200
|
||||
new_refresh = resp.json()["refresh_token"]
|
||||
assert new_refresh != refresh
|
||||
|
||||
# Old refresh is revoked after rotation.
|
||||
reuse = await api.post("/api/v1/auth/refresh", json={"refresh_token": refresh})
|
||||
assert reuse.status_code == 401
|
||||
|
||||
|
||||
async def test_logout_revokes(api: AsyncClient) -> None:
|
||||
_, refresh = await _login(api, "admin", "adminpass1")
|
||||
out = await api.post("/api/v1/auth/logout", json={"refresh_token": refresh})
|
||||
assert out.status_code == 204
|
||||
reuse = await api.post("/api/v1/auth/refresh", json={"refresh_token": refresh})
|
||||
assert reuse.status_code == 401
|
||||
|
||||
|
||||
async def test_admin_creates_user_and_nonadmin_forbidden(api: AsyncClient) -> None:
|
||||
admin_access, _ = await _login(api, "admin", "adminpass1")
|
||||
admin_headers = {"Authorization": f"Bearer {admin_access}"}
|
||||
|
||||
created = await api.post(
|
||||
"/api/v1/admin/users",
|
||||
headers=admin_headers,
|
||||
json={"username": "carol", "password": "carolpass1"},
|
||||
)
|
||||
assert created.status_code == 201, created.text
|
||||
assert created.json()["is_superuser"] is False
|
||||
|
||||
# Non-admin cannot reach admin routes.
|
||||
user_access, _ = await _login(api, "carol", "carolpass1")
|
||||
forbidden = await api.get(
|
||||
"/api/v1/admin/users", headers={"Authorization": f"Bearer {user_access}"}
|
||||
)
|
||||
assert forbidden.status_code == 403
|
||||
|
||||
|
||||
async def test_admin_create_duplicate_conflicts(api: AsyncClient) -> None:
|
||||
admin_access, _ = await _login(api, "admin", "adminpass1")
|
||||
headers = {"Authorization": f"Bearer {admin_access}"}
|
||||
payload = {"username": "dave", "password": "davepass12"}
|
||||
|
||||
first = await api.post("/api/v1/admin/users", headers=headers, json=payload)
|
||||
assert first.status_code == 201
|
||||
dup = await api.post("/api/v1/admin/users", headers=headers, json=payload)
|
||||
assert dup.status_code == 409
|
||||
|
||||
|
||||
async def test_deactivated_user_cannot_login(api: AsyncClient) -> None:
|
||||
admin_access, _ = await _login(api, "admin", "adminpass1")
|
||||
headers = {"Authorization": f"Bearer {admin_access}"}
|
||||
created = await api.post(
|
||||
"/api/v1/admin/users",
|
||||
headers=headers,
|
||||
json={"username": "erin", "password": "erinpass12"},
|
||||
)
|
||||
user_id = created.json()["id"]
|
||||
|
||||
deactivate = await api.delete(f"/api/v1/admin/users/{user_id}", headers=headers)
|
||||
assert deactivate.status_code == 200
|
||||
assert deactivate.json()["is_active"] is False
|
||||
|
||||
resp = await api.post("/api/v1/auth/login", json={"username": "erin", "password": "erinpass12"})
|
||||
assert resp.status_code == 401
|
||||
@@ -0,0 +1,97 @@
|
||||
"""Unit tests for AuthService using in-memory ports."""
|
||||
|
||||
import pytest
|
||||
from app.application.auth_service import AuthService
|
||||
from app.application.user_service import UserService
|
||||
from app.core.config import Settings
|
||||
from app.core.security import Argon2PasswordHasher, JwtTokenService
|
||||
from app.domain.errors import AuthenticationError
|
||||
|
||||
from tests.fakes import InMemoryRefreshTokenRepository, InMemoryUserRepository
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def env() -> tuple[AuthService, UserService, InMemoryUserRepository]:
|
||||
users = InMemoryUserRepository()
|
||||
refresh = InMemoryRefreshTokenRepository()
|
||||
hasher = Argon2PasswordHasher()
|
||||
tokens = JwtTokenService(Settings(jwt_secret="svc-test-secret"))
|
||||
auth = AuthService(users=users, refresh_tokens=refresh, hasher=hasher, tokens=tokens)
|
||||
user_svc = UserService(users=users, refresh_tokens=refresh, hasher=hasher)
|
||||
return auth, user_svc, users
|
||||
|
||||
|
||||
async def test_login_success_then_authenticate(
|
||||
env: tuple[AuthService, UserService, object],
|
||||
) -> None:
|
||||
auth, user_svc, _ = env
|
||||
created = await user_svc.create_user(username="alice", password="password123")
|
||||
|
||||
pair = await auth.login("alice", "password123")
|
||||
user = await auth.authenticate_access(pair.access.encoded)
|
||||
assert user.id == created.id
|
||||
assert user.username == "alice"
|
||||
|
||||
|
||||
async def test_login_wrong_password(env: tuple[AuthService, UserService, object]) -> None:
|
||||
auth, user_svc, _ = env
|
||||
await user_svc.create_user(username="alice", password="password123")
|
||||
with pytest.raises(AuthenticationError):
|
||||
await auth.login("alice", "nope")
|
||||
|
||||
|
||||
async def test_login_unknown_user(env: tuple[AuthService, UserService, object]) -> None:
|
||||
auth, _, _ = env
|
||||
with pytest.raises(AuthenticationError):
|
||||
await auth.login("ghost", "whatever")
|
||||
|
||||
|
||||
async def test_login_inactive_user(env: tuple[AuthService, UserService, object]) -> None:
|
||||
auth, user_svc, _ = env
|
||||
user = await user_svc.create_user(username="alice", password="password123")
|
||||
await user_svc.set_active(user.id, is_active=False)
|
||||
with pytest.raises(AuthenticationError):
|
||||
await auth.login("alice", "password123")
|
||||
|
||||
|
||||
async def test_refresh_rotates_and_invalidates_old(
|
||||
env: tuple[AuthService, UserService, object],
|
||||
) -> None:
|
||||
auth, user_svc, _ = env
|
||||
await user_svc.create_user(username="alice", password="password123")
|
||||
pair = await auth.login("alice", "password123")
|
||||
|
||||
new_pair = await auth.refresh(pair.refresh.encoded)
|
||||
assert new_pair.refresh.encoded != pair.refresh.encoded
|
||||
|
||||
# Old refresh token is now revoked (rotation) — reuse must fail.
|
||||
with pytest.raises(AuthenticationError):
|
||||
await auth.refresh(pair.refresh.encoded)
|
||||
|
||||
# New one still works.
|
||||
await auth.refresh(new_pair.refresh.encoded)
|
||||
|
||||
|
||||
async def test_access_token_not_accepted_as_refresh(
|
||||
env: tuple[AuthService, UserService, object],
|
||||
) -> None:
|
||||
auth, user_svc, _ = env
|
||||
await user_svc.create_user(username="alice", password="password123")
|
||||
pair = await auth.login("alice", "password123")
|
||||
with pytest.raises(AuthenticationError):
|
||||
await auth.refresh(pair.access.encoded)
|
||||
|
||||
|
||||
async def test_logout_revokes_refresh(env: tuple[AuthService, UserService, object]) -> None:
|
||||
auth, user_svc, _ = env
|
||||
await user_svc.create_user(username="alice", password="password123")
|
||||
pair = await auth.login("alice", "password123")
|
||||
|
||||
await auth.logout(pair.refresh.encoded)
|
||||
with pytest.raises(AuthenticationError):
|
||||
await auth.refresh(pair.refresh.encoded)
|
||||
|
||||
|
||||
async def test_logout_ignores_garbage(env: tuple[AuthService, UserService, object]) -> None:
|
||||
auth, _, _ = env
|
||||
await auth.logout("not-a-jwt") # must not raise
|
||||
@@ -0,0 +1,78 @@
|
||||
"""Unit tests for the security adapters (no DB, no network)."""
|
||||
|
||||
import datetime as dt
|
||||
import uuid
|
||||
|
||||
import jwt
|
||||
import pytest
|
||||
from app.core.config import Settings
|
||||
from app.core.security import Argon2PasswordHasher, JwtTokenService
|
||||
from app.domain.errors import AuthenticationError
|
||||
from app.domain.tokens import TokenType
|
||||
|
||||
|
||||
def _settings(**overrides: object) -> Settings:
|
||||
base: dict[str, object] = {"jwt_secret": "unit-test-secret", "access_token_ttl_seconds": 900}
|
||||
base.update(overrides)
|
||||
return Settings(**base) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_password_hash_roundtrip() -> None:
|
||||
hasher = Argon2PasswordHasher()
|
||||
hashed = hasher.hash("correct horse battery staple")
|
||||
assert hashed != "correct horse battery staple"
|
||||
|
||||
valid, updated = hasher.verify_and_update("correct horse battery staple", hashed)
|
||||
assert valid is True
|
||||
assert updated is None # fresh hash, no rehash needed
|
||||
|
||||
wrong, _ = hasher.verify_and_update("wrong password", hashed)
|
||||
assert wrong is False
|
||||
|
||||
|
||||
def test_jwt_issue_and_decode_roundtrip() -> None:
|
||||
svc = JwtTokenService(_settings())
|
||||
subject = uuid.uuid4()
|
||||
|
||||
issued = svc.issue(subject=subject, token_type=TokenType.ACCESS)
|
||||
claims = svc.decode(issued.encoded)
|
||||
|
||||
assert claims.subject == subject
|
||||
assert claims.token_type is TokenType.ACCESS
|
||||
assert claims.jti == issued.jti
|
||||
|
||||
|
||||
def test_jwt_rejects_tampered_token() -> None:
|
||||
svc = JwtTokenService(_settings())
|
||||
issued = svc.issue(subject=uuid.uuid4(), token_type=TokenType.ACCESS)
|
||||
tampered = issued.encoded[:-2] + ("aa" if issued.encoded[-2:] != "aa" else "bb")
|
||||
|
||||
with pytest.raises(AuthenticationError):
|
||||
svc.decode(tampered)
|
||||
|
||||
|
||||
def test_jwt_rejects_wrong_secret() -> None:
|
||||
issuer = JwtTokenService(_settings(jwt_secret="secret-a"))
|
||||
verifier = JwtTokenService(_settings(jwt_secret="secret-b"))
|
||||
issued = issuer.issue(subject=uuid.uuid4(), token_type=TokenType.ACCESS)
|
||||
|
||||
with pytest.raises(AuthenticationError):
|
||||
verifier.decode(issued.encoded)
|
||||
|
||||
|
||||
def test_jwt_rejects_expired_token() -> None:
|
||||
settings = _settings()
|
||||
secret = settings.jwt_secret.get_secret_value()
|
||||
expired = jwt.encode(
|
||||
{
|
||||
"sub": str(uuid.uuid4()),
|
||||
"type": "access",
|
||||
"jti": str(uuid.uuid4()),
|
||||
"iat": int((dt.datetime.now(dt.UTC) - dt.timedelta(hours=2)).timestamp()),
|
||||
"exp": int((dt.datetime.now(dt.UTC) - dt.timedelta(hours=1)).timestamp()),
|
||||
},
|
||||
secret,
|
||||
algorithm=settings.jwt_algorithm,
|
||||
)
|
||||
with pytest.raises(AuthenticationError):
|
||||
JwtTokenService(settings).decode(expired)
|
||||
@@ -0,0 +1,88 @@
|
||||
"""Unit tests for UserService using in-memory ports."""
|
||||
|
||||
import pytest
|
||||
from app.application.auth_service import AuthService
|
||||
from app.application.user_service import UserService
|
||||
from app.core.config import Settings
|
||||
from app.core.security import Argon2PasswordHasher, JwtTokenService
|
||||
from app.domain.errors import AlreadyExistsError, AuthenticationError, NotFoundError
|
||||
|
||||
from tests.fakes import InMemoryRefreshTokenRepository, InMemoryUserRepository
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def env() -> tuple[UserService, AuthService]:
|
||||
users = InMemoryUserRepository()
|
||||
refresh = InMemoryRefreshTokenRepository()
|
||||
hasher = Argon2PasswordHasher()
|
||||
tokens = JwtTokenService(Settings(jwt_secret="u-test-secret"))
|
||||
user_svc = UserService(users=users, refresh_tokens=refresh, hasher=hasher)
|
||||
auth = AuthService(users=users, refresh_tokens=refresh, hasher=hasher, tokens=tokens)
|
||||
return user_svc, auth
|
||||
|
||||
|
||||
async def test_create_user_duplicate_username(env: tuple[UserService, AuthService]) -> None:
|
||||
user_svc, _ = env
|
||||
await user_svc.create_user(username="bob", password="password123")
|
||||
with pytest.raises(AlreadyExistsError):
|
||||
await user_svc.create_user(username="bob", password="another-one")
|
||||
|
||||
|
||||
async def test_get_unknown_user_raises(env: tuple[UserService, AuthService]) -> None:
|
||||
import uuid
|
||||
|
||||
user_svc, _ = env
|
||||
with pytest.raises(NotFoundError):
|
||||
await user_svc.get_user(uuid.uuid4())
|
||||
|
||||
|
||||
async def test_change_password_requires_current(env: tuple[UserService, AuthService]) -> None:
|
||||
user_svc, auth = env
|
||||
user = await user_svc.create_user(username="bob", password="password123")
|
||||
|
||||
with pytest.raises(AuthenticationError):
|
||||
await user_svc.change_password(
|
||||
user.id, current_password="wrong", new_password="newpassword1"
|
||||
)
|
||||
|
||||
await user_svc.change_password(
|
||||
user.id, current_password="password123", new_password="newpassword1"
|
||||
)
|
||||
# New password works, old one no longer.
|
||||
await auth.login("bob", "newpassword1")
|
||||
with pytest.raises(AuthenticationError):
|
||||
await auth.login("bob", "password123")
|
||||
|
||||
|
||||
async def test_change_password_revokes_sessions(env: tuple[UserService, AuthService]) -> None:
|
||||
user_svc, auth = env
|
||||
user = await user_svc.create_user(username="bob", password="password123")
|
||||
pair = await auth.login("bob", "password123")
|
||||
|
||||
await user_svc.change_password(
|
||||
user.id, current_password="password123", new_password="newpassword1"
|
||||
)
|
||||
with pytest.raises(AuthenticationError):
|
||||
await auth.refresh(pair.refresh.encoded)
|
||||
|
||||
|
||||
async def test_reset_password_revokes_sessions(env: tuple[UserService, AuthService]) -> None:
|
||||
user_svc, auth = env
|
||||
user = await user_svc.create_user(username="bob", password="password123")
|
||||
pair = await auth.login("bob", "password123")
|
||||
|
||||
await user_svc.reset_password(user.id, new_password="adminset12")
|
||||
with pytest.raises(AuthenticationError):
|
||||
await auth.refresh(pair.refresh.encoded)
|
||||
await auth.login("bob", "adminset12")
|
||||
|
||||
|
||||
async def test_deactivate_revokes_sessions(env: tuple[UserService, AuthService]) -> None:
|
||||
user_svc, auth = env
|
||||
user = await user_svc.create_user(username="bob", password="password123")
|
||||
pair = await auth.login("bob", "password123")
|
||||
|
||||
deactivated = await user_svc.deactivate(user.id)
|
||||
assert deactivated.is_active is False
|
||||
with pytest.raises(AuthenticationError):
|
||||
await auth.refresh(pair.refresh.encoded)
|
||||
Reference in New Issue
Block a user