# cortex-hub / ai-hub / tests / db / test_models.py
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.exc import DetachedInstanceError
from app.db.models import Base, Session, Message, Document, VectorMetadata

# All tests run against an in-memory SQLite database: nothing touches the
# disk, and the db_session fixture rebuilds the tables for every test.
SQLALCHEMY_DATABASE_URL = "sqlite:///:memory:"

# Engine for the in-memory database; check_same_thread=False is handed
# through to the SQLite driver as a connect argument.
engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False})

# Session factory bound to the test engine, with autocommit and autoflush off.
TestingSessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)


@pytest.fixture(scope="function")
def db_session():
    """
    Provide an isolated database session for a single test.

    The schema is built on the module-level engine before the session is
    handed to the test and is dropped again once the test has finished, so
    neither rows nor tables carry over from one test to the next.
    """
    Base.metadata.create_all(bind=engine)
    session = TestingSessionLocal()
    try:
        yield session
    finally:
        # Teardown order: release the session first, then discard the schema.
        session.close()
        Base.metadata.drop_all(bind=engine)


def test_create_all_tables_with_relationships(db_session):
    """
    Tests that all tables are registered, actually created, and correctly linked.

    Three things are checked:
      * the expected table names are present in the declarative metadata,
      * the same tables exist in the database that ``db_session`` is bound to,
      * the foreign keys of ``messages`` and ``vector_metadata`` point at the
        expected parent tables.
    """
    # Imported locally so the test does not rely on a module-level import.
    from sqlalchemy import inspect

    expected_tables = {"sessions", "messages", "documents", "vector_metadata"}

    # The models must be registered on the declarative metadata ...
    missing_in_metadata = expected_tables - set(Base.metadata.tables)
    assert not missing_in_metadata

    # ... and the tables must really exist in the database. The metadata is
    # filled when the models are imported, so the check above passes even if
    # no table was ever created.
    created_tables = set(inspect(db_session.get_bind()).get_table_names())
    missing_in_database = expected_tables - created_tables
    assert not missing_in_database

    # Check for foreign key constraints to ensure relationships are set up
    message_table = Base.metadata.tables["messages"]
    vector_metadata_table = Base.metadata.tables["vector_metadata"]

    # messages has exactly one foreign key, and it targets sessions.
    message_fk_targets = [fk.column.table.name for fk in message_table.foreign_keys]
    assert message_fk_targets == ["sessions"]

    # vector_metadata has exactly two foreign keys: one to documents and one
    # to sessions.
    vector_fk_targets = [fk.column.table.name for fk in vector_metadata_table.foreign_keys]
    assert sorted(vector_fk_targets) == ["documents", "sessions"]


def test_create_and_retrieve_session(db_session):
    """
    Round-trips a Session row: persist it, query it back by its id, and
    compare the stored fields with the values it was created with.
    """
    chat = Session(user_id="test-user-123", title="Test Session", model_name="gemini")
    db_session.add(chat)
    db_session.commit()
    # Reload the instance's state from the database after the commit.
    db_session.refresh(chat)

    by_id = db_session.query(Session).filter(Session.id == chat.id)
    fetched = by_id.first()

    assert fetched is not None
    stored_fields = (fetched.user_id, fetched.title, fetched.model_name)
    assert stored_fields == ("test-user-123", "Test Session", "gemini")


def test_create_message_with_session_relationship(db_session):
    """
    Persists a Message that points at a Session and checks the link through
    the foreign-key column as well as through the ``session`` relationship
    attribute.
    """
    # The parent row has to exist before a message can reference it.
    parent = Session(user_id="test-user-123")
    db_session.add(parent)
    db_session.commit()
    db_session.refresh(parent)

    message_fields = {
        "session_id": parent.id,
        "sender": "user",
        "content": "Hello, world!",
        "token_count": 3,
        "model_response_time": 0,
    }
    child = Message(**message_fields)
    db_session.add(child)
    db_session.commit()
    db_session.refresh(child)

    fetched = db_session.query(Message).filter(Message.id == child.id).first()

    assert fetched is not None
    assert fetched.session_id == parent.id
    assert fetched.content == "Hello, world!"
    # Navigating the relationship must lead back to the parent session.
    assert fetched.session.user_id == "test-user-123"


def test_create_document_and_vector_metadata_relationship(db_session):
    """
    Persists a Document together with a VectorMetadata row that references
    it, then checks that the metadata is reachable from the document through
    the ``vector_metadata`` relationship attribute.
    """
    document_fields = {
        "user_id": "test-user-123",
        "title": "Sample Doc",
        "text": "This is some sample text.",
        "status": "ready",
    }
    doc = Document(**document_fields)
    db_session.add(doc)
    db_session.commit()
    db_session.refresh(doc)

    # The metadata row refers to the document through its id.
    meta = VectorMetadata(document_id=doc.id, faiss_index=123, embedding_model="test-model")
    db_session.add(meta)
    db_session.commit()
    db_session.refresh(meta)

    fetched = db_session.query(Document).filter(Document.id == doc.id).first()

    assert fetched is not None
    linked_meta = fetched.vector_metadata
    assert linked_meta is not None
    assert linked_meta.faiss_index == 123


def test_cascade_delete_session_and_messages(db_session):
    """
    Tests that deleting a session also removes the messages that belong to it.

    The session's id is copied into a local variable before the delete, so
    the final assertions never read attributes from an instance that has
    already been deleted and committed.
    """
    # Create a session and remember its id while the instance is persistent.
    new_session = Session(user_id="cascade-test")
    db_session.add(new_session)
    db_session.commit()
    db_session.refresh(new_session)
    session_id = new_session.id

    # Attach two messages to the session.
    db_session.add_all(
        [
            Message(session_id=session_id, sender="user", content=content)
            for content in ("Msg 1", "Msg 2")
        ]
    )
    db_session.commit()

    # Each call to count() on this query runs a fresh SELECT, so it can be
    # reused before and after the delete.
    messages_of_session = db_session.query(Message).filter(Message.session_id == session_id)

    # Check that messages exist before deletion
    assert messages_of_session.count() == 2

    # Delete the session
    db_session.delete(new_session)
    db_session.commit()

    # Check that the session is gone and its messages were removed with it
    assert db_session.query(Session).filter(Session.id == session_id).count() == 0
    assert messages_of_session.count() == 0


def test_cascade_delete_document_and_vector_metadata(db_session):
    """
    Tests that deleting a document also removes its vector metadata.

    The document's id is copied into a local variable before the delete, so
    the final assertions never read attributes from an instance that has
    already been deleted and committed.
    """
    # Create a document and remember its id while the instance is persistent.
    new_document = Document(user_id="cascade-test", title="Cascade Doc", text="Content")
    db_session.add(new_document)
    db_session.commit()
    db_session.refresh(new_document)
    document_id = new_document.id

    # Link a vector metadata entry to the document.
    vector_meta = VectorMetadata(document_id=document_id, faiss_index=999, embedding_model="test")
    db_session.add(vector_meta)
    db_session.commit()
    db_session.refresh(vector_meta)

    # Each call to count() on this query runs a fresh SELECT, so it can be
    # reused before and after the delete.
    metadata_of_document = db_session.query(VectorMetadata).filter(
        VectorMetadata.document_id == document_id
    )

    # Check that the vector metadata exists
    assert metadata_of_document.count() == 1

    # Delete the document
    db_session.delete(new_document)
    db_session.commit()

    # Check that the document is gone and its vector metadata was removed with it
    assert db_session.query(Document).filter(Document.id == document_id).count() == 0
    assert metadata_of_document.count() == 0