135 lines
4.7 KiB
Python
135 lines
4.7 KiB
Python
"""Database models and session management."""
|
|
import asyncio
|
|
from typing import AsyncGenerator
|
|
from sqlalchemy import create_engine, BigInteger, String, Integer, Boolean, ForeignKey, Column
|
|
from sqlalchemy.orm import declarative_base, sessionmaker, relationship, Session
|
|
from sqlalchemy.schema import UniqueConstraint
|
|
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
|
import socket
|
|
from bot.config import Config
|
|
from bot.logger import get_logger
|
|
|
|
# Logger for this module, obtained from the project's get_logger helper.
logger = get_logger(__name__)

# Declarative base class that the ORM models in this module (User, Chat,
# UserChat) inherit from; its metadata is what init_db() passes to create_all().
Base = declarative_base()
|
|
|
|
|
|
class User(Base):
    """ORM model for the ``users`` table: one user and their birthday.

    Day and month of the birthday are required; the year is optional.
    """

    __tablename__ = "users"

    # Column declaration order is significant: it is the column order used
    # when the table is created.
    user_id = Column(BigInteger, primary_key=True)
    username = Column(String, nullable=True)
    first_name = Column(String, nullable=False)
    birthday_day = Column(Integer, nullable=False)
    birthday_month = Column(Integer, nullable=False)
    birthday_year = Column(Integer, nullable=True)
    preference_theme = Column(String, nullable=False)

    # Membership rows (UserChat) for this user. With "all, delete-orphan",
    # ORM-level deletes of a User also delete its membership rows.
    chats = relationship(
        "UserChat",
        cascade="all, delete-orphan",
        back_populates="user",
    )

    def __repr__(self):
        return "<User(user_id={}, first_name={})>".format(self.user_id, self.first_name)
|
|
|
|
|
|
class Chat(Base):
    """ORM model for the ``chats`` table: a chat the bot knows about."""

    __tablename__ = "chats"

    chat_id = Column(BigInteger, primary_key=True)
    chat_title = Column(String, nullable=False)
    # Python-side default: new rows start with the bot not marked as admin.
    bot_is_admin = Column(Boolean, default=False, nullable=False)

    # Membership rows (UserChat) for this chat. With "all, delete-orphan",
    # ORM-level deletes of a Chat also delete its membership rows.
    users = relationship(
        "UserChat",
        cascade="all, delete-orphan",
        back_populates="chat",
    )

    def __repr__(self):
        return "<Chat(chat_id={}, chat_title={})>".format(self.chat_id, self.chat_title)
|
|
|
|
|
|
class UserChat(Base):
    """Association model linking users and chats (many-to-many)."""

    __tablename__ = "user_chats"

    # Composite primary key: at most one row per (user, chat) pair.
    user_id = Column(BigInteger, ForeignKey("users.user_id"), primary_key=True)
    chat_id = Column(BigInteger, ForeignKey("chats.chat_id"), primary_key=True)

    user = relationship("User", back_populates="chats")
    chat = relationship("Chat", back_populates="users")

    # NOTE(review): the composite primary key above already enforces
    # uniqueness of (user_id, chat_id), so this constraint creates a second,
    # redundant unique index. It is kept because removing it changes the
    # schema and the constraint name that error handling might match on;
    # consider dropping it through a migration.
    __table_args__ = (UniqueConstraint("user_id", "chat_id", name="unique_user_chat"),)

    def __repr__(self):
        return "<UserChat(user_id={}, chat_id={})>".format(self.user_id, self.chat_id)
|
|
|
|
|
|
# Database engine and session (async)


def _to_asyncpg_url(url: str) -> str:
    """Return *url* rewritten to use SQLAlchemy's asyncpg PostgreSQL driver.

    Only the URL scheme is rewritten. Both ``postgresql://`` and the legacy
    ``postgres://`` spelling (still emitted by some hosting providers, but no
    longer accepted as a dialect name by SQLAlchemy 1.4+) become
    ``postgresql+asyncpg://``. Any other URL, including one that already names
    a driver such as ``postgresql+asyncpg://``, is returned unchanged.
    """
    for scheme in ("postgresql://", "postgres://"):
        if url.startswith(scheme):
            return "postgresql+asyncpg://" + url[len(scheme):]
    return url


# create_async_engine requires an async driver, so force the asyncpg dialect.
async_database_url = _to_asyncpg_url(Config.DATABASE_URL)

# Configure connection pool with better timeout settings
async_engine = create_async_engine(
    async_database_url,
    echo=False,
    pool_pre_ping=True,  # Verify connections before using them
    pool_size=5,  # Number of connections to maintain
    max_overflow=10,  # Maximum overflow connections
    pool_recycle=3600,  # Recycle connections after 1 hour (seconds)
    pool_timeout=30,  # Seconds to wait for a connection from the pool
    connect_args={
        # Passed through to the asyncpg driver's connect() call.
        "server_settings": {
            "application_name": "bdbot",
        },
        "command_timeout": 60,  # Command timeout (seconds)
    },
)

# Session factory; expire_on_commit=False keeps loaded attributes usable after
# commit without triggering a new (async) load.
AsyncSessionLocal = async_sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)
|
|
|
|
|
|
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield an async database session from ``AsyncSessionLocal``.

    The session is closed when the generator finishes or is closed: leaving
    the ``async with`` block calls the session's ``close()``, so no separate
    ``finally``/``close()`` is needed.
    """
    async with AsyncSessionLocal() as session:
        yield session
|
|
|
|
|
|
async def init_db(max_retries: int = 5, base_delay: float = 2) -> None:
    """Initialize the database by creating all tables in ``Base.metadata``.

    Transient connection failures (DNS/network errors and connect timeouts)
    are retried with exponential backoff. Any other error is logged and
    re-raised immediately.

    Args:
        max_retries: Total number of attempts before giving up (must be >= 1).
        base_delay: Seconds to wait before the first retry; each later retry
            doubles the wait. With the defaults the waits are 2, 4, 8 and 16
            seconds (no wait follows the final attempt).

    Raises:
        ValueError: If ``max_retries`` is less than 1.
        OSError, asyncio.TimeoutError: If every attempt fails to connect.
        Exception: Any non-transient error raised while creating tables.
    """
    if max_retries < 1:
        raise ValueError(f"max_retries must be at least 1, got {max_retries}")

    logger.info("Initializing database...")

    # socket.gaierror (DNS failure) is a subclass of OSError and is listed for
    # clarity. asyncio.TimeoutError (connect timeout) is only an OSError
    # subclass from Python 3.11 on, so it must be named explicitly.
    retryable = (socket.gaierror, OSError, asyncio.TimeoutError)

    for attempt in range(1, max_retries + 1):
        try:
            async with async_engine.begin() as conn:
                await conn.run_sync(Base.metadata.create_all)
        except retryable as e:
            if attempt == max_retries:
                logger.error(f"Failed to connect to database after {max_retries} attempts: {e}", exc_info=True)
                raise
            delay = base_delay * (2 ** (attempt - 1))  # exponential backoff
            logger.warning(
                f"Database connection failed (attempt {attempt}/{max_retries}): {e}. "
                f"Retrying in {delay} seconds..."
            )
            await asyncio.sleep(delay)
        except Exception as e:
            logger.error(f"Error initializing database: {e}", exc_info=True)
            raise
        else:
            logger.info("Database initialized successfully")
            return
|
|
|
|
|
|
def get_db_session():
    """Return a new async generator that yields one database session.

    Thin alias for ``get_db()``; the caller drives the returned generator.
    """
    session_generator = get_db()
    return session_generator
|