"""Database models and session management."""
import asyncio
import socket
from typing import AsyncGenerator

from sqlalchemy import create_engine, BigInteger, String, Integer, Boolean, ForeignKey, Column
from sqlalchemy.orm import declarative_base, sessionmaker, relationship, Session
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker

from bot.config import Config
from bot.logger import get_logger

logger = get_logger(__name__)

# Declarative base shared by every ORM model in this module.
Base = declarative_base()


class User(Base):
    """User model - stores user information and birthday.

    The birthday is stored as separate day/month columns (both required) and
    an optional year, so a user may register a birthday without a year.
    """

    __tablename__ = "users"

    user_id = Column(BigInteger, primary_key=True)
    username = Column(String, nullable=True)
    first_name = Column(String, nullable=False)
    birthday_day = Column(Integer, nullable=False)
    birthday_month = Column(Integer, nullable=False)
    birthday_year = Column(Integer, nullable=True)  # optional: year may be unknown
    preference_theme = Column(String, nullable=False)

    # Relationships
    # Deleting a User through the ORM also deletes its UserChat link rows.
    chats = relationship("UserChat", back_populates="user", cascade="all, delete-orphan")

    def __repr__(self) -> str:
        # Only plain column attributes are used; relationships are left out so
        # that building a repr never has to load related rows.
        return (
            f"<User(user_id={self.user_id!r}, username={self.username!r}, "
            f"first_name={self.first_name!r})>"
        )


class Chat(Base):
    """Chat model - stores chat information."""

    __tablename__ = "chats"

    chat_id = Column(BigInteger, primary_key=True)
    chat_title = Column(String, nullable=False)
    # Client-side default applied by SQLAlchemy on INSERT when not set explicitly.
    bot_is_admin = Column(Boolean, default=False, nullable=False)

    # Relationships
    # Deleting a Chat through the ORM also deletes its UserChat link rows.
    users = relationship("UserChat", back_populates="chat", cascade="all, delete-orphan")

    def __repr__(self) -> str:
        # Column attributes only; see User.__repr__.
        return f"<Chat(chat_id={self.chat_id!r}, chat_title={self.chat_title!r})>"


class UserChat(Base):
    """Many-to-many relationship between users and chats.

    Each row links one ``users.user_id`` to one ``chats.chat_id``; the pair
    forms the composite primary key.
    """

    __tablename__ = "user_chats"

    user_id = Column(BigInteger, ForeignKey("users.user_id"), primary_key=True)
    chat_id = Column(BigInteger, ForeignKey("chats.chat_id"), primary_key=True)

    # Relationships
    user = relationship("User", back_populates="chats")
    chat = relationship("Chat", back_populates="users")

    # NOTE(review): this duplicates the uniqueness already enforced by the
    # composite primary key above; kept as-is so the generated schema (and the
    # named constraint "unique_user_chat") does not change.
    __table_args__ = (
        UniqueConstraint("user_id", "chat_id", name="unique_user_chat"),
    )

    def __repr__(self) -> str:
        return f"<UserChat(user_id={self.user_id!r}, chat_id={self.chat_id!r})>"
# Database engine and session (async)


def _to_async_database_url(url: str) -> str:
    """Return *url* rewritten to use the asyncpg driver.

    Both the ``postgresql://`` scheme and the legacy ``postgres://`` scheme
    (which SQLAlchemy 1.4+ no longer accepts as a dialect name) are rewritten
    to ``postgresql+asyncpg://``. Only the scheme prefix is touched. URLs that
    already name a driver, or use another backend, are returned unchanged.
    """
    for scheme in ("postgresql://", "postgres://"):
        if url.startswith(scheme):
            return "postgresql+asyncpg://" + url[len(scheme):]
    return url


# Convert postgresql:// (or postgres://) to postgresql+asyncpg://
async_database_url = _to_async_database_url(Config.DATABASE_URL)

# Configure connection pool with better timeout settings
async_engine = create_async_engine(
    async_database_url,
    echo=False,
    pool_pre_ping=True,  # Verify connections before using them
    pool_size=5,  # Number of connections to maintain
    max_overflow=10,  # Maximum overflow connections
    pool_recycle=3600,  # Recycle connections after 1 hour
    pool_timeout=30,  # Timeout for getting connection from pool
    connect_args={
        "server_settings": {
            "application_name": "bdbot",
        },
        "command_timeout": 60,  # Command timeout
    },
)

# expire_on_commit=False keeps loaded attribute values usable after commit
# without triggering a refresh query.
AsyncSessionLocal = async_sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)


async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Get async database session.

    Yields a single ``AsyncSession``; the ``async with`` block closes it when
    the generator is exhausted or closed, so no explicit ``close()`` is needed.
    """
    async with AsyncSessionLocal() as session:
        yield session


async def init_db(max_retries: int = 5, base_delay: float = 2) -> None:
    """Initialize database - create all tables with retry logic.

    Runs ``Base.metadata.create_all`` inside a transaction. Network-level
    failures (DNS resolution, refused/reset connections, timeouts) are retried
    with exponential backoff: the wait before retry *k* (1-based) is
    ``base_delay * 2 ** (k - 1)`` seconds, i.e. 2, 4, 8, 16 with the defaults.

    Args:
        max_retries: Total number of attempts (not extra retries). Values
            below 1 are treated as 1 so the tables are always attempted.
        base_delay: Seconds to wait before the first retry.

    Raises:
        OSError | asyncio.TimeoutError: if every attempt fails with a network
            error (the last error is re-raised).
        Exception: any other error from table creation, re-raised immediately.
    """
    logger.info("Initializing database...")
    attempts = max(1, max_retries)

    for attempt in range(attempts):
        try:
            async with async_engine.begin() as conn:
                await conn.run_sync(Base.metadata.create_all)
        except (OSError, asyncio.TimeoutError) as e:
            # DNS resolution or network errors. socket.gaierror is an OSError
            # subclass; asyncio.TimeoutError is listed separately because it
            # is only an OSError on Python 3.11+.
            # NOTE(review): driver-specific "server not ready" errors that are
            # not OSError subclasses fall through to the branch below and are
            # not retried — confirm whether that is desired.
            if attempt < attempts - 1:
                delay = base_delay * (2 ** attempt)
                logger.warning(
                    f"Database connection failed (attempt {attempt + 1}/{attempts}): {e}. "
                    f"Retrying in {delay} seconds..."
                )
                await asyncio.sleep(delay)
            else:
                logger.error(f"Failed to connect to database after {attempts} attempts: {e}", exc_info=True)
                raise
        except Exception as e:
            logger.error(f"Error initializing database: {e}", exc_info=True)
            raise
        else:
            logger.info("Database initialized successfully")
            return


def get_db_session():
    """Get an async database session generator (see ``get_db``)."""
    return get_db()