"""Database models and session management."""
from typing import AsyncGenerator
from sqlalchemy import create_engine, BigInteger, String, Integer, Boolean, ForeignKey, Column
from sqlalchemy.orm import declarative_base, sessionmaker, relationship, Session
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker

from bot.config import Config
from bot.logger import get_logger

logger = get_logger(__name__)

# Declarative base shared by every model in this module; init_db() creates
# the tables registered on Base.metadata.
Base = declarative_base()


class User(Base):
    """User model - stores user information and birthday.

    The birthday is stored as separate day/month/year integer columns;
    the year is optional (nullable).
    """

    __tablename__ = "users"

    user_id = Column(BigInteger, primary_key=True)
    username = Column(String, nullable=True)
    first_name = Column(String, nullable=False)
    birthday_day = Column(Integer, nullable=False)
    birthday_month = Column(Integer, nullable=False)
    birthday_year = Column(Integer, nullable=True)
    preference_theme = Column(String, nullable=False)

    # Relationships
    # Membership rows are owned by the user: the cascade removes them when
    # the user is deleted or when they are detached from this collection.
    chats = relationship("UserChat", back_populates="user", cascade="all, delete-orphan")

    def __repr__(self) -> str:
        """Return a debug representation identifying the user."""
        return f"User(user_id={self.user_id!r}, username={self.username!r})"


class Chat(Base):
    """Chat model - stores chat information."""

    __tablename__ = "chats"

    chat_id = Column(BigInteger, primary_key=True)
    chat_title = Column(String, nullable=False)
    bot_is_admin = Column(Boolean, default=False, nullable=False)

    # Relationships
    # Membership rows are owned by the chat and are removed together with it.
    users = relationship("UserChat", back_populates="chat", cascade="all, delete-orphan")

    def __repr__(self) -> str:
        """Return a debug representation identifying the chat."""
        return f"Chat(chat_id={self.chat_id!r}, chat_title={self.chat_title!r})"


class UserChat(Base):
    """Many-to-many relationship between users and chats.

    Each row links one user to one chat; the pair (user_id, chat_id) is the
    composite primary key.
    """

    __tablename__ = "user_chats"

    user_id = Column(BigInteger, ForeignKey("users.user_id"), primary_key=True)
    chat_id = Column(BigInteger, ForeignKey("chats.chat_id"), primary_key=True)

    # Relationships
    user = relationship("User", back_populates="chats")
    chat = relationship("Chat", back_populates="users")

    # NOTE(review): the composite primary key already guarantees uniqueness
    # of (user_id, chat_id); the named constraint is kept so the generated
    # schema stays exactly as before.
    __table_args__ = (
        UniqueConstraint("user_id", "chat_id", name="unique_user_chat"),
    )

    def __repr__(self) -> str:
        """Return a debug representation identifying the membership row."""
        return f"UserChat(user_id={self.user_id!r}, chat_id={self.chat_id!r})"


# Database engine and session (async)
# Convert postgresql://
# to postgresql+asyncpg:// (the legacy postgres:// alias is rewritten as well).
def _to_asyncpg_url(url: str) -> str:
    """Return *url* with its scheme switched to ``postgresql+asyncpg://``.

    ``postgresql://`` is rewritten exactly as before (first occurrence only).
    A leading ``postgres://`` is also rewritten, because SQLAlchemy 1.4+ no
    longer accepts that dialect name. Any other URL is returned unchanged.
    """
    legacy_prefix = "postgres://"
    if url.startswith(legacy_prefix):
        return "postgresql+asyncpg://" + url[len(legacy_prefix):]
    return url.replace("postgresql://", "postgresql+asyncpg://", 1)


async_database_url = _to_asyncpg_url(Config.DATABASE_URL)

# The engine and session factory are created once, at import time.
async_engine = create_async_engine(async_database_url, echo=False)
# expire_on_commit=False keeps already-loaded attribute values on instances
# after commit() instead of expiring them.
AsyncSessionLocal = async_sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)


async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Get async database session.

    Yields:
        A single ``AsyncSession`` created by ``AsyncSessionLocal``.

    The ``async with`` block closes the session when the generator is
    exhausted or closed, including when the consumer raises, so no explicit
    ``close()`` call is needed.
    """
    async with AsyncSessionLocal() as session:
        yield session


async def init_db() -> None:
    """Initialize database - create all tables.

    Runs ``Base.metadata.create_all`` inside one engine-level transaction.

    Raises:
        Exception: any error raised while creating the tables is logged
            with its traceback and then re-raised unchanged.
    """
    logger.info("Initializing database...")
    try:
        async with async_engine.begin() as conn:
            # create_all is a synchronous API; run_sync bridges it onto the
            # async connection.
            await conn.run_sync(Base.metadata.create_all)
        logger.info("Database initialized successfully")
    except Exception as e:
        logger.error(f"Error initializing database: {e}", exc_info=True)
        raise


def get_db_session() -> AsyncGenerator[AsyncSession, None]:
    """Get an async database session generator.

    Returns:
        The async generator produced by ``get_db()``; it yields one session.
    """
    return get_db()