105 lines
3.2 KiB
Python
105 lines
3.2 KiB
Python
from fastapi import Depends, HTTPException, status
|
|
from typing import Annotated
|
|
from sqlalchemy.orm import Session
|
|
from uuid import UUID
|
|
import redis
|
|
import asyncio
|
|
|
|
from ...dependencies import get_db, get_pubsub, get_redis
|
|
from ...db import models
|
|
|
|
from ..auth import services as auth_services
|
|
from ..auth import schemas as auth_schemas
|
|
|
|
from . import schemas
|
|
|
|
|
|
def get_queue_by_id(queue_id: UUID, db: Session) -> "models.Queue | None":
    """Look up a single queue by its primary key.

    Args:
        queue_id: Identifier of the queue to fetch.
        db: Open SQLAlchemy session used for the query.

    Returns:
        The matching ``models.Queue`` row, or ``None`` when no queue has this
        id (``Query.first()`` yields ``None`` for an empty result).
    """
    # The annotation is a string so the ``X | None`` union is never evaluated
    # at definition time; the previous ``-> models.Queue`` hid the None case
    # that callers such as join_queue have to (and do) check for.
    return db.query(models.Queue).filter(models.Queue.id == queue_id).first()
|
|
|
|
|
|
def get_user_queues(
    current_user: Annotated[auth_schemas.User, Depends(auth_services.get_current_user)]
) -> list[schemas.QueueInDb]:
    """List the queues owned by the authenticated user.

    ``current_user`` is resolved by FastAPI via
    ``auth_services.get_current_user``. Every item of its ``owns_queues``
    collection is converted, in order, with ``schemas.QueueInDb.model_validate``.
    """
    to_schema = schemas.QueueInDb.model_validate
    return list(map(to_schema, current_user.owns_queues))
|
|
|
|
|
|
def create_queue(
    new_queue: schemas.Queue,
    current_user: auth_schemas.UserInDB,
    db: Session,
) -> schemas.QueueInDb:
    """Create and commit a new queue owned by ``current_user``.

    Args:
        new_queue: Incoming payload supplying the queue's ``name`` and
            ``description``.
        current_user: User recorded as the queue's owner (``owner_id``).
        db: SQLAlchemy session the new row is added to and committed on.

    Returns:
        The committed queue validated into ``schemas.QueueInDb``.
    """
    column_values = {
        "name": new_queue.name,
        "description": new_queue.description,
        "owner_id": current_user.id,
    }
    queue_row = models.Queue(**column_values)

    db.add(queue_row)
    db.commit()

    return schemas.QueueInDb.model_validate(queue_row)
|
|
|
|
|
|
def get_detailed_queue(
    queue_id: UUID,
    db: Annotated[Session, Depends(get_db)],
) -> schemas.QueueDetail:
    """Build the detailed view of a queue, including its waiting participants.

    Args:
        queue_id: Identifier of the queue to describe.
        db: SQLAlchemy session (injected through ``get_db`` when used as a
            FastAPI dependency; may also be passed explicitly).

    Returns:
        ``schemas.QueueDetail`` with the queue's id, name and description, plus
        a ``ParticipantInfo`` holding:
          * ``total``: number of entries in ``q.users``;
          * ``remaining``: number of entries with ``passed == False``;
          * ``users_list``: those not-yet-passed entries ordered by ascending
            ``position``.

    Raises:
        HTTPException: 404 when no queue with ``queue_id`` exists.
    """
    q = db.query(models.Queue).filter(models.Queue.id == queue_id).first()
    if q is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Not Found",
        )

    # Execute the waiting-list query here with .all() rather than handing a
    # lazy Query object to pydantic. Its length is exactly the "remaining"
    # count (same filter), so one COUNT round-trip is saved.
    waiting = (
        # `== False` (not `is False` / `not ...`) is required: SQLAlchemy
        # overloads `==` to build the SQL comparison.
        q.users.filter(models.QueueUser.passed == False)  # noqa: E712
        .order_by(models.QueueUser.position.asc())
        .all()
    )

    return schemas.QueueDetail(
        id=q.id,
        name=q.name,
        description=q.description,
        participants=schemas.ParticipantInfo(
            total=q.users.count(),
            remaining=len(waiting),
            users_list=waiting,
        ),
    )
|
|
|
|
|
|
async def join_queue(
    queue_id: UUID,
    client: Annotated[auth_schemas.AnonUser, Depends(auth_services.get_anon_user)],
    db: Annotated[Session, Depends(get_db)],
    r: Annotated[redis.client.Redis, Depends(get_redis)],
) -> schemas.QueueUser:
    """Append the anonymous ``client`` to the end of queue ``queue_id``.

    The new entry gets position 0 when the queue has no entries, otherwise
    one more than the current highest position. After the commit, the string
    ``"updated"`` is published on the channel named ``str(queue_id)`` so that
    listeners (see set_queue_listener) can refresh.

    Returns:
        The newly created ``models.QueueUser`` row.

    Raises:
        HTTPException: 404 if the queue does not exist; 409 if ``client``
            already has an entry in this queue.
    """
    queue = get_queue_by_id(queue_id, db)
    if not queue:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Not Found",
        )

    existing_entry = queue.users.filter(models.QueueUser.user_id == client.id).first()
    if existing_entry:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Already joined",
        )

    # NOTE(review): reading the highest position and inserting are separate
    # steps, so two concurrent joins could receive the same position — confirm
    # whether a DB constraint or lock guards against this.
    tail = queue.users.order_by(models.QueueUser.position.desc()).first()
    next_position = 0
    if tail:
        next_position = tail.position + 1

    entry = models.QueueUser(
        user_id=client.id,
        queue_id=queue.id,
        position=next_position,
    )
    db.add(entry)
    db.commit()

    # NOTE(review): `r` is annotated as the synchronous redis.client.Redis, yet
    # publish() is awaited; presumably get_redis yields an asyncio client —
    # verify and align the annotation.
    await r.publish(str(queue_id), "updated")
    return entry
|
|
|
|
|
|
async def set_queue_listener(
    queue_id: UUID,
    db: Annotated[Session, Depends(get_db)],
    ps: Annotated[redis.client.PubSub, Depends(get_pubsub)],
) -> schemas.QueueDetail:
    """Wait for the next "updated" event on a queue, then return its details.

    Subscribes ``ps`` to the channel ``str(queue_id)`` (the channel join_queue
    publishes to). It then waits until a message whose payload is ``"updated"``
    arrives, or until ``ps.listen()`` stops yielding. The subscription is
    always removed before returning or propagating an exception.

    Returns:
        A freshly built ``schemas.QueueDetail`` from get_detailed_queue.

    Raises:
        HTTPException: 404 (from get_detailed_queue) if the queue no longer
            exists once the event arrives.
    """
    # NOTE(review): `ps` is annotated as the synchronous redis.client.PubSub,
    # but its methods are awaited; presumably get_pubsub yields an asyncio
    # PubSub — verify and align the annotation.
    await ps.subscribe(str(queue_id))
    try:
        async for message in ps.listen():
            # Accept both payload types: without response decoding the data is
            # bytes, with decode_responses=True it is str. Subscribe
            # confirmations carry an int and never match.
            if message.get("data", None) in (b"updated", "updated"):
                break
    finally:
        # Previously skipped if listen() raised or the awaiting task was
        # cancelled, which left the pubsub connection subscribed.
        await ps.unsubscribe()

    return get_detailed_queue(queue_id=queue_id, db=db)
|