Files
bdbot/bot/database.py
2026-01-28 15:53:27 +03:00

99 lines
3.3 KiB
Python

"""Database models and session management."""
from typing import AsyncGenerator
from sqlalchemy import create_engine, BigInteger, String, Integer, Boolean, ForeignKey, Column
from sqlalchemy.orm import declarative_base, sessionmaker, relationship, Session
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from bot.config import Config
from bot.logger import get_logger
# Logger for this module, obtained from the project's logging helper.
logger = get_logger(__name__)
# Declarative base for the ORM models below; each model subclass registers its
# table on ``Base.metadata``, which ``init_db`` passes to ``create_all``.
Base = declarative_base()
class User(Base):
    """A bot user and their birthday.

    Maps to the ``users`` table; ``user_id`` is the primary key.
    """

    __tablename__ = "users"

    user_id = Column(BigInteger, primary_key=True)
    username = Column(String, nullable=True)  # optional
    first_name = Column(String, nullable=False)
    # Birthday parts. NOTE(review): presumably day 1-31 / month 1-12 — no CHECK
    # constraint enforces that here, so callers must validate before writing.
    birthday_day = Column(Integer, nullable=False)
    birthday_month = Column(Integer, nullable=False)
    birthday_year = Column(Integer, nullable=True)  # may be left unset
    preference_theme = Column(String, nullable=False)  # TODO(review): document allowed values

    # Association rows linking this user to chats. With "all, delete-orphan",
    # UserChat rows follow their User through the session (deletes included)
    # and are deleted when removed from this collection.
    chats = relationship(
        "UserChat",
        cascade="all, delete-orphan",
        back_populates="user",
    )

    def __repr__(self):
        return f"<User(user_id={self.user_id}, first_name={self.first_name})>"
class Chat(Base):
    """A chat the bot tracks.

    Maps to the ``chats`` table; ``chat_id`` is the primary key.
    """

    __tablename__ = "chats"

    chat_id = Column(BigInteger, primary_key=True)
    chat_title = Column(String, nullable=False)
    # ``default`` is applied by SQLAlchemy on INSERT (no server-side default).
    bot_is_admin = Column(Boolean, nullable=False, default=False)

    # Association rows linking this chat to users; same cascade rules as
    # ``User.chats`` (children follow the parent, orphans are deleted).
    users = relationship(
        "UserChat",
        cascade="all, delete-orphan",
        back_populates="chat",
    )

    def __repr__(self):
        return f"<Chat(chat_id={self.chat_id}, chat_title={self.chat_title})>"
class UserChat(Base):
    """Association row for the many-to-many link between users and chats."""

    __tablename__ = "user_chats"
    # NOTE(review): the composite primary key (user_id, chat_id) below already
    # guarantees uniqueness of the pair, so this constraint is redundant (on
    # PostgreSQL it is backed by a second unique index). Kept so the generated
    # schema does not change.
    __table_args__ = (
        UniqueConstraint("user_id", "chat_id", name="unique_user_chat"),
    )

    # Both columns together form the composite primary key. The foreign keys
    # declare no ``ondelete`` action, so only ORM-level cascades (on User.chats
    # / Chat.users) remove these rows when a parent is deleted.
    user_id = Column(BigInteger, ForeignKey("users.user_id"), primary_key=True)
    chat_id = Column(BigInteger, ForeignKey("chats.chat_id"), primary_key=True)

    user = relationship("User", back_populates="chats")
    chat = relationship("Chat", back_populates="users")

    def __repr__(self):
        return f"<UserChat(user_id={self.user_id}, chat_id={self.chat_id})>"
# Database engine and session (async)


def _to_async_database_url(url: str) -> str:
    """Return *url* with a bare PostgreSQL scheme switched to the asyncpg driver.

    Both ``postgresql://`` and the legacy ``postgres://`` alias (no longer
    accepted as a dialect name by SQLAlchemy 1.4+) become
    ``postgresql+asyncpg://``. Only the scheme is inspected, so the rest of the
    URL is never rewritten. URLs that already name a driver
    (``postgresql+asyncpg://``, ``postgresql+psycopg://``, ...), use another
    backend, or have no ``://`` separator are returned unchanged.
    """
    scheme, separator, remainder = url.partition("://")
    if separator and scheme in ("postgresql", "postgres"):
        return f"postgresql+asyncpg://{remainder}"
    return url


# Convert postgresql:// (or postgres://) to postgresql+asyncpg://
async_database_url = _to_async_database_url(Config.DATABASE_URL)
async_engine = create_async_engine(async_database_url, echo=False)
# expire_on_commit=False: loaded attributes are not expired when a session commits.
AsyncSessionLocal = async_sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield one session created by ``AsyncSessionLocal``.

    Leaving the ``async with`` block (normal completion, an exception thrown
    into the generator, or ``aclose()``) closes the session via AsyncSession's
    async context manager, so no extra ``finally: await session.close()`` is
    needed.

    Yields:
        An ``AsyncSession`` bound to ``async_engine``.
    """
    async with AsyncSessionLocal() as session:
        yield session
async def init_db() -> None:
    """Create every table registered on ``Base.metadata``.

    ``create_all`` runs synchronously on the connection of a single
    ``async_engine.begin()`` block via ``run_sync``. Any failure is logged with
    its traceback and then re-raised to the caller.
    """
    logger.info("Initializing database...")
    try:
        async with async_engine.begin() as connection:
            create_tables = Base.metadata.create_all
            await connection.run_sync(create_tables)
        logger.info("Database initialized successfully")
    except Exception as e:
        logger.error(f"Error initializing database: {e}", exc_info=True)
        raise
def get_db_session():
    """Return a new, not-yet-started ``get_db()`` async generator.

    Nothing is opened until the caller advances it (``async for`` / ``anext``),
    at which point it yields a single ``AsyncSession``.
    """
    session_generator = get_db()
    return session_generator