async (not working tbh)

This commit is contained in:
2024-04-14 14:45:01 +03:00
parent 17e0f78ecf
commit 033dbc538e
11 changed files with 95 additions and 50 deletions

View File

@ -1,4 +1,4 @@
from sqlalchemy import create_engine from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
import os import os
@ -9,11 +9,11 @@ POSTGRES_DB = os.environ.get("POSTGRES_DB", "db")
POSTGRES_HOST = os.environ.get("POSTGRES_HOST", "postgres") POSTGRES_HOST = os.environ.get("POSTGRES_HOST", "postgres")
SQLALCHEMY_DATABASE_URL = ( SQLALCHEMY_DATABASE_URL = f"postgresql+asyncpg://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_HOST}/{POSTGRES_DB}"
f"postgresql://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_HOST}/{POSTGRES_DB}"
engine = create_async_engine(SQLALCHEMY_DATABASE_URL)
async_session = sessionmaker(
autocommit=False, class_=AsyncSession, autoflush=False, bind=engine
) )
engine = create_engine(SQLALCHEMY_DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base() Base = declarative_base()

View File

@ -1,9 +1,11 @@
from .db.database import SessionLocal from sqlalchemy.ext.asyncio import AsyncSession
from .db.database import async_session
def get_db(): async def get_db() -> AsyncSession:
db = SessionLocal() async with async_session() as session:
try: try:
yield db yield session
finally: finally:
db.close() session.close()

View File

@ -2,7 +2,7 @@ from typing import Union
from fastapi import FastAPI, Depends from fastapi import FastAPI, Depends
from .db import models from .db import models
from .db.database import SessionLocal, engine from .db.database import engine
from .dependencies import get_db from .dependencies import get_db
from .views.auth.api import router as auth_router from .views.auth.api import router as auth_router
@ -10,13 +10,19 @@ from .views.queue.api import router as queue_router
from .views.news.api import router as news_router from .views.news.api import router as news_router
app = FastAPI(dependencies=[Depends(get_db)]) app = FastAPI(dependencies=[Depends(get_db)])
models.Base.metadata.create_all(bind=engine)
app.include_router(queue_router) app.include_router(queue_router)
app.include_router(auth_router) app.include_router(auth_router)
app.include_router(news_router) app.include_router(news_router)
@app.on_event("startup")
async def init_tables():
async with engine.begin() as conn:
await conn.run_sync(models.Base.metadata.create_all)
@app.get("/") @app.get("/")
async def read_root(): async def read_root():
return {"message": "OK"} return {"message": "OK"}

View File

@ -27,7 +27,7 @@ async def login_for_access_token(
form_data: Annotated[OAuth2PasswordRequestForm, Depends()], form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
db: Annotated[Session, Depends(get_db)], db: Annotated[Session, Depends(get_db)],
) -> schemas.Token: ) -> schemas.Token:
user = services.authenticate_user(db, form_data.username, form_data.password) user = await services.authenticate_user(db, form_data.username, form_data.password)
if not user: if not user:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_409_CONFLICT, status_code=status.HTTP_409_CONFLICT,
@ -46,7 +46,7 @@ async def register(
user_data: schemas.UserRegister, user_data: schemas.UserRegister,
db: Annotated[Session, Depends(get_db)], db: Annotated[Session, Depends(get_db)],
) -> schemas.User: ) -> schemas.User:
user = services.get_user_by_username(db, user_data.username) user = await services.get_user_by_username(db, user_data.username)
if user: if user:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@ -59,7 +59,7 @@ async def register(
detail="Passwords do not match", detail="Passwords do not match",
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
user = services.create_user(db=db, user_data=user_data) user = await services.create_user(db=db, user_data=user_data)
return user return user

View File

@ -27,3 +27,11 @@ class Token(BaseModel):
class TokenData(BaseModel): class TokenData(BaseModel):
username: Union[str, None] = None username: Union[str, None] = None
class AnonUser(BaseModel):
id: UUID
name: str
class Config:
from_attributes = True

View File

@ -1,6 +1,9 @@
from fastapi import status, HTTPException, Depends from fastapi import status, HTTPException, Depends, Header
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.orm import Session
# from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from jose import JWTError, jwt from jose import JWTError, jwt
from typing import Annotated, Union from typing import Annotated, Union
from datetime import datetime, timezone, timedelta from datetime import datetime, timezone, timedelta
@ -24,16 +27,18 @@ def get_password_hash(password) -> str:
return pwd_context.hash(password) return pwd_context.hash(password)
def get_user_by_id(db: Session, user_id: uuid.uuid4) -> models.User: async def get_user_by_id(db: AsyncSession, user_id: uuid.uuid4) -> models.User:
return db.query(models.User).filter(models.User.id == user_id).first() u = await db.execute(select(models.User).filter(models.User.id == user_id))
return u.scalar_one_or_none()
def get_user_by_username(db: Session, username: int) -> models.User: async def get_user_by_username(db: AsyncSession, username: int) -> models.User:
return db.query(models.User).filter(models.User.username == username).first() u = await db.execute(select(models.User).filter(models.User.username == username))
return u.scalar_one_or_none()
def authenticate_user(db: Session, username: str, password: str): async def authenticate_user(db: AsyncSession, username: str, password: str):
user = get_user_by_username(db, username) user = await get_user_by_username(db, username)
if not user: if not user:
return False return False
if not verify_password(password, user.hashed_password): if not verify_password(password, user.hashed_password):
@ -54,20 +59,22 @@ def create_access_token(data: dict, expires_delta: Union[timedelta, None] = None
return encoded_jwt return encoded_jwt
def create_user(db: Session, user_data: schemas.UserRegister) -> schemas.UserInDB: async def create_user(
db: AsyncSession, user_data: schemas.UserRegister
) -> schemas.UserInDB:
user = models.User( user = models.User(
username=user_data.username, username=user_data.username,
name=user_data.name, name=user_data.name,
hashed_password=get_password_hash(user_data.password), hashed_password=get_password_hash(user_data.password),
) )
db.add(user) db.add(user)
db.commit() await db.commit()
return schemas.UserInDB.model_validate(user) return schemas.UserInDB.model_validate(user)
async def get_current_user( async def get_current_user(
token: Annotated[str, Depends(oauth2_scheme)], token: Annotated[str, Depends(oauth2_scheme)],
db: Annotated[Session, Depends(get_db)], db: Annotated[AsyncSession, Depends(get_db)],
) -> schemas.UserInDB: ) -> schemas.UserInDB:
credentials_exception = HTTPException( credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -84,7 +91,7 @@ async def get_current_user(
token_data = schemas.TokenData(username=username) token_data = schemas.TokenData(username=username)
except JWTError: except JWTError:
raise credentials_exception raise credentials_exception
user = get_user_by_username(db, username=token_data.username) user = await get_user_by_username(db, username=token_data.username)
if user is None: if user is None:
raise credentials_exception raise credentials_exception
return user return user
@ -92,7 +99,7 @@ async def get_current_user(
async def get_current_user_or_none( async def get_current_user_or_none(
token: Annotated[str, Depends(oauth2_scheme)], token: Annotated[str, Depends(oauth2_scheme)],
db: Annotated[Session, Depends(get_db)], db: Annotated[AsyncSession, Depends(get_db)],
) -> Union[schemas.UserInDB, None]: ) -> Union[schemas.UserInDB, None]:
try: try:
payload = jwt.decode( payload = jwt.decode(
@ -104,7 +111,7 @@ async def get_current_user_or_none(
token_data = schemas.TokenData(username=username) token_data = schemas.TokenData(username=username)
except JWTError: except JWTError:
return None return None
user = get_user_by_username(db, username=token_data.username) user = await get_user_by_username(db, username=token_data.username)
return user return user
@ -114,3 +121,24 @@ async def get_current_active_user(
if not current_user.is_active: if not current_user.is_active:
raise HTTPException(status_code=400, detail="Inactive user") raise HTTPException(status_code=400, detail="Inactive user")
return current_user return current_user
async def create_anon_user(
db: Annotated[AsyncSession, Depends(get_db)]
) -> schemas.AnonUser:
u = models.AnonymousUser()
db.add(u)
await db.commit()
return schemas.AnonUser.model_validate(u)
async def get_anon_user(
db: Annotated[AsyncSession, Depends(get_db)],
device_id: Annotated[Union[str, None], Header()] = None,
) -> schemas.AnonUser:
if device_id:
u = await db.execute(
select(models.AnonymousUser).filter(models.AnonymousUser.id == device_id)
)
return schemas.AnonUser.model_validate(u.scalar_one_or_none())
return await create_anon_user(db)

View File

@ -39,7 +39,8 @@ async def create_news(
current_user: Annotated[auth_schemas.User, Depends(auth_services.get_current_user)], current_user: Annotated[auth_schemas.User, Depends(auth_services.get_current_user)],
db: Annotated[Session, Depends(get_db)], db: Annotated[Session, Depends(get_db)],
) -> schemas.NewsInDb: ) -> schemas.NewsInDb:
return services.create_news(news=news, current_user=current_user, db=db) n = await services.create_news(news=news, current_user=current_user, db=db)
return n
@router.post("/{news_id}/tap") @router.post("/{news_id}/tap")

View File

@ -1,6 +1,7 @@
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status
from typing import Annotated from typing import Annotated
from sqlalchemy.orm import Session from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from uuid import UUID from uuid import UUID
from ...dependencies import get_db from ...dependencies import get_db
@ -12,24 +13,23 @@ from ..auth import schemas as auth_schemas
from . import schemas from . import schemas
def get_news( async def get_news(
db: Annotated[Session, Depends(get_db)], db: Annotated[AsyncSession, Depends(get_db)],
) -> list[schemas.NewsInDb]: ) -> list[schemas.NewsInDb]:
return [ news = await db.execute(select(models.News).order_by(models.News.created.desc()))
schemas.NewsInDb.model_validate(n) return [schemas.NewsInDb.model_validate(n) for n in news.scalars().all()]
for n in db.query(models.News).order_by(models.News.created.desc()).all()
]
def create_news( async def create_news(
news: schemas.CreateNews, news: schemas.CreateNews,
current_user: auth_schemas.UserInDB, current_user: auth_schemas.UserInDB,
db: Session, db: AsyncSession,
) -> schemas.NewsInDb: ) -> schemas.NewsInDb:
if current_user.username == "admin": if current_user.username == "admin":
n = models.News(title=news.title, content=news.content) n = models.News(title=news.title, content=news.content)
db.add(n) db.add(n)
db.commit() await db.commit()
print(f"\n\n{n.title}\n\n", flush=True)
return schemas.NewsInDb.model_validate(n) return schemas.NewsInDb.model_validate(n)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -38,7 +38,7 @@ def create_news(
) )
def tap_news(news_id: UUID, db: Session): async def tap_news(news_id: UUID, db: AsyncSession):
n = db.query(models.News).filter(models.News.id == news_id).first() n = db.query(models.News).filter(models.News.id == news_id).first()
if n: if n:
setattr(n, "taps", n.taps + 1) setattr(n, "taps", n.taps + 1)

View File

@ -12,13 +12,13 @@ from ..auth import schemas as auth_schemas
from . import schemas from . import schemas
def get_user_queues( async def get_user_queues(
current_user: Annotated[auth_schemas.User, Depends(auth_services.get_current_user)] current_user: Annotated[auth_schemas.User, Depends(auth_services.get_current_user)]
) -> list[schemas.QueueInDb]: ) -> list[schemas.QueueInDb]:
return [schemas.QueueInDb.model_validate(q) for q in current_user.owns_queues] return [schemas.QueueInDb.model_validate(q) for q in current_user.owns_queues]
def create_queue( async def create_queue(
new_queue: schemas.Queue, new_queue: schemas.Queue,
current_user: auth_schemas.UserInDB, current_user: auth_schemas.UserInDB,
db: Session, db: Session,
@ -31,7 +31,7 @@ def create_queue(
return schemas.QueueInDb.model_validate(q) return schemas.QueueInDb.model_validate(q)
def get_detailed_queue( async def get_detailed_queue(
queue_id: UUID, queue_id: UUID,
db: Annotated[Session, Depends(get_db)], db: Annotated[Session, Depends(get_db)],
) -> schemas.QueueDetail: ) -> schemas.QueueDetail:

View File

@ -2,6 +2,6 @@ fastapi[all]
uvicorn uvicorn
pydantic pydantic
sqlalchemy sqlalchemy
psycopg2-binary asyncpg
python-jose[cryptography] python-jose[cryptography]
passlib[all] passlib[all]

View File

@ -8,7 +8,7 @@ export default defineConfig({
port: 3000, port: 3000,
}, },
html: { html: {
favicon: "./static/favicon-32x32.png", favicon: "./static/android-chrome-512x512.png",
title: "queueful!", title: "queueful!",
template: "./static/index.html", template: "./static/index.html",
}, },