86 lines
2.7 KiB
Python
86 lines
2.7 KiB
Python
"""Database models and session management."""
|
|
from typing import Generator
|
|
from sqlalchemy import create_engine, BigInteger, String, Integer, Boolean, ForeignKey, Column
|
|
from sqlalchemy.orm import declarative_base, sessionmaker, relationship, Session
|
|
from sqlalchemy.schema import UniqueConstraint
|
|
from config import Config
|
|
|
|
# Declarative base class shared by every ORM model in this module; its
# ``metadata`` collects the table definitions that ``init_db`` creates.
Base = declarative_base(name="Base")
|
|
|
|
|
|
class User(Base):
    """ORM model for the ``users`` table: one row per user plus their birthday.

    Columns:
        user_id: primary key.
        username: may be NULL.
        first_name: required.
        birthday_day, birthday_month: required.
        birthday_year: may be NULL.
        preference_theme: required.
    """

    __tablename__ = "users"

    user_id = Column(BigInteger, primary_key=True)
    username = Column(String, nullable=True)
    first_name = Column(String, nullable=False)

    # Day and month are mandatory; the year column accepts NULL.
    birthday_day = Column(Integer, nullable=False)
    birthday_month = Column(Integer, nullable=False)
    birthday_year = Column(Integer, nullable=True)

    preference_theme = Column(String, nullable=False)

    # Association rows (UserChat) for this user. "all, delete-orphan" makes the
    # ORM delete them together with the user, or when removed from this list.
    chats = relationship(
        "UserChat",
        cascade="all, delete-orphan",
        back_populates="user",
    )

    def __repr__(self):
        return "<User(user_id={}, first_name={})>".format(self.user_id, self.first_name)
|
|
|
|
|
|
class Chat(Base):
    """ORM model for the ``chats`` table: one row per chat.

    Columns:
        chat_id: primary key.
        chat_title: required.
        bot_is_admin: required boolean; False when not supplied on insert.
    """

    __tablename__ = "chats"

    chat_id = Column(BigInteger, primary_key=True)
    chat_title = Column(String, nullable=False)
    bot_is_admin = Column(Boolean, nullable=False, default=False)

    # Association rows (UserChat) for this chat. "all, delete-orphan" makes the
    # ORM delete them together with the chat, or when removed from this list.
    users = relationship(
        "UserChat",
        cascade="all, delete-orphan",
        back_populates="chat",
    )

    def __repr__(self):
        return "<Chat(chat_id={}, chat_title={})>".format(self.chat_id, self.chat_title)
|
|
|
|
|
|
class UserChat(Base):
    """Association object for the many-to-many link between users and chats.

    Each row records that ``user_id`` is linked to ``chat_id``. The composite
    primary key (user_id, chat_id) already guarantees each pair appears once.
    """

    __tablename__ = "user_chats"

    # ondelete="CASCADE" makes the database agree with the ORM-level
    # "all, delete-orphan" cascade declared on User.chats / Chat.users.
    # Deleting a user or chat outside the ORM unit of work (e.g. a bulk
    # Query.delete()) then removes the link rows. Without it, the delete fails
    # with a foreign-key violation, or leaves orphans where FKs are unenforced.
    # NOTE: create_all() does not alter existing tables; an existing database
    # needs a migration to pick this up.
    user_id = Column(
        BigInteger,
        ForeignKey("users.user_id", ondelete="CASCADE"),
        primary_key=True,
    )
    chat_id = Column(
        BigInteger,
        ForeignKey("chats.chat_id", ondelete="CASCADE"),
        primary_key=True,
    )

    # Relationships
    user = relationship("User", back_populates="chats")
    chat = relationship("Chat", back_populates="users")

    # NOTE(review): redundant with the composite primary key above. It is kept
    # so the named constraint "unique_user_chat" still exists for any schema
    # or migration that references it.
    __table_args__ = (
        UniqueConstraint("user_id", "chat_id", name="unique_user_chat"),
    )

    def __repr__(self):
        return f"<UserChat(user_id={self.user_id}, chat_id={self.chat_id})>"
|
|
|
|
|
|
# Engine for the configured database URL (SQL echo turned off), and the session
# factory used throughout this module: autocommit and autoflush both disabled.
engine = create_engine(Config.DATABASE_URL, echo=False)
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
|
|
|
|
|
|
def get_db() -> Generator[Session, None, None]:
    """Yield a fresh session from ``SessionLocal`` and close it afterwards.

    This is a generator. The session is closed once the generator is resumed
    past the ``yield`` or is closed, whether or not an exception was raised.
    """
    # Session.__exit__ calls close(), equivalent to an explicit try/finally.
    with SessionLocal() as session:
        yield session
|
|
|
|
|
|
def init_db() -> None:
    """Create every table registered on ``Base.metadata`` that does not exist yet.

    Existing tables are skipped (``checkfirst=True``), so this does not apply
    schema changes to tables that are already present.
    """
    Base.metadata.create_all(engine, checkfirst=True)
|
|
|
|
|
|
def get_db_session() -> Session:
    """Return a new session from ``SessionLocal`` that nothing closes automatically.

    The caller owns the returned session and is responsible for closing it.
    """
    session = SessionLocal()
    return session
|