import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.exc import DetachedInstanceError
from app.db.models import Base, Session, Message, Document, VectorMetadata
# Tests run against an in-memory SQLite database: it is fast and nothing is
# written to disk.
SQLALCHEMY_DATABASE_URL = "sqlite:///:memory:"

# Engine used by every test in this module. ``check_same_thread`` is handed
# through to the sqlite3 driver and switches off its same-thread check.
engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args=dict(check_same_thread=False),
)

# Factory producing ORM sessions bound to the test engine. Nothing is flushed
# or committed implicitly; tests call commit() themselves.
TestingSessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
@pytest.fixture(scope="function")
def db_session():
    """
    Provide a fresh database session to a single test.

    Every table registered on ``Base.metadata`` is created before the session
    is handed to the test. When the test finishes, whether it passed or
    failed, the session is closed and all tables are dropped again so that
    neither rows nor schema carry over into the next test.
    """
    # Build the schema for this test.
    Base.metadata.create_all(bind=engine)
    session = TestingSessionLocal()
    try:
        yield session
    finally:
        # Close the session first, then tear the schema down.
        session.close()
        Base.metadata.drop_all(bind=engine)
def test_create_all_tables_with_relationships(db_session):
    """
    Tests that every expected table exists and that the foreign keys backing
    the model relationships are declared.

    The table check looks at the live database behind ``db_session`` as well
    as at ``Base.metadata``; checking the metadata alone would pass even if
    the fixture's ``create_all`` had not emitted any tables.
    """
    # Imported locally so the test does not depend on a module-level import.
    from sqlalchemy import inspect as sa_inspect

    expected_tables = {"sessions", "messages", "documents", "vector_metadata"}

    # The models must be registered on the declarative metadata...
    missing_from_metadata = expected_tables - set(Base.metadata.tables)
    assert not missing_from_metadata

    # ...and the tables must actually exist in the database.
    created_tables = set(sa_inspect(db_session.get_bind()).get_table_names())
    missing_from_database = expected_tables - created_tables
    assert not missing_from_database

    # Check for foreign key constraints to ensure relationships are set up.
    message_table = Base.metadata.tables["messages"]
    vector_metadata_table = Base.metadata.tables["vector_metadata"]

    # messages carries exactly one foreign key, and it points at sessions.
    message_fk_tables = [fk.column.table.name for fk in message_table.foreign_keys]
    assert message_fk_tables == ["sessions"]

    # vector_metadata carries two foreign keys: one to documents, one to sessions.
    assert len(vector_metadata_table.foreign_keys) == 2
    vector_fk_tables = {
        fk.column.table.name for fk in vector_metadata_table.foreign_keys
    }
    assert "documents" in vector_fk_tables
    assert "sessions" in vector_fk_tables
def test_create_and_retrieve_session(db_session):
    """
    Tests that a Session row can be stored and then looked up again by its ID.
    """
    expected_fields = {
        "user_id": "test-user-123",
        "title": "Test Session",
        "provider_name": "gemini",
    }

    # Persist the object and reload it so its generated ID is populated.
    created = Session(**expected_fields)
    db_session.add(created)
    db_session.commit()
    db_session.refresh(created)

    # Look the row up by primary key.
    fetched = db_session.query(Session).filter(Session.id == created.id).first()

    # The fetched row must exist and carry the values it was created with.
    assert fetched is not None
    for field_name, expected_value in expected_fields.items():
        assert getattr(fetched, field_name) == expected_value
def test_create_message_with_session_relationship(db_session):
    """
    Tests that a Message can be stored against a Session and that the
    ``session`` relationship on the message resolves to that Session.
    """
    # The message needs an existing session to point at.
    owner = Session(user_id="test-user-123")
    db_session.add(owner)
    db_session.commit()
    db_session.refresh(owner)

    # Store a message that references the session by ID.
    message = Message(
        session_id=owner.id,
        sender="user",
        content="Hello, world!",
        token_count=3,
        model_response_time=0,
    )
    db_session.add(message)
    db_session.commit()
    db_session.refresh(message)

    # Read the message back by primary key and check its stored columns.
    by_id = db_session.query(Message).filter(Message.id == message.id)
    fetched = by_id.first()
    assert fetched is not None
    assert fetched.session_id == owner.id
    assert fetched.content == "Hello, world!"

    # Navigating the relationship must reach the owning session.
    assert fetched.session.user_id == "test-user-123"
def test_create_document_and_vector_metadata_relationship(db_session):
    """
    Tests that a VectorMetadata row can be attached to a Document and is
    reachable through the document's ``vector_metadata`` attribute.
    """
    # Store the document first so it has an ID to reference.
    document = Document(
        user_id="test-user-123",
        title="Sample Doc",
        text="This is some sample text.",
        status="ready",
    )
    db_session.add(document)
    db_session.commit()
    db_session.refresh(document)

    # Store vector metadata, with an explicit ID, pointing at the document.
    metadata_row = VectorMetadata(
        id=123,
        document_id=document.id,
        embedding_model="test-model",
    )
    db_session.add(metadata_row)
    db_session.commit()
    db_session.refresh(metadata_row)

    # Read the document back and follow the relationship to its metadata.
    by_id = db_session.query(Document).filter(Document.id == document.id)
    fetched = by_id.first()
    assert fetched is not None
    assert fetched.vector_metadata is not None
    assert fetched.vector_metadata.id == 123
def test_cascade_delete_session_and_messages(db_session):
    """
    Tests that deleting a session automatically deletes its associated
    messages due to cascading.
    """
    # Create a session.
    new_session = Session(user_id="cascade-test")
    db_session.add(new_session)
    db_session.commit()
    db_session.refresh(new_session)

    # Capture the primary key now. After the delete is committed the instance
    # is detached from the session, so the later assertions should not depend
    # on reading attributes from it.
    session_id = new_session.id

    # Create two messages that belong to the session.
    messages = [
        Message(session_id=session_id, sender="user", content="Msg 1"),
        Message(session_id=session_id, sender="user", content="Msg 2"),
    ]
    db_session.add_all(messages)
    db_session.commit()

    # Check that the messages exist before deletion.
    assert db_session.query(Message).filter(Message.session_id == session_id).count() == 2

    # Delete the session.
    db_session.delete(new_session)
    db_session.commit()

    # Check that the session is gone and the messages have been cascaded.
    assert db_session.query(Session).filter(Session.id == session_id).count() == 0
    assert db_session.query(Message).filter(Message.session_id == session_id).count() == 0
def test_cascade_delete_document_and_vector_metadata(db_session):
    """
    Tests that deleting a document automatically deletes its vector metadata
    due to cascading.
    """
    # Create a document.
    new_document = Document(user_id="cascade-test", title="Cascade Doc", text="Content")
    db_session.add(new_document)
    db_session.commit()
    db_session.refresh(new_document)

    # Capture the primary key now. After the delete is committed the instance
    # is detached from the session, so the later assertions should not depend
    # on reading attributes from it.
    document_id = new_document.id

    # Create a vector metadata entry linked to the document.
    vector_meta = VectorMetadata(id=999, document_id=document_id, embedding_model="test")
    db_session.add(vector_meta)
    db_session.commit()
    db_session.refresh(vector_meta)

    # Check that the vector metadata exists before deletion.
    assert (
        db_session.query(VectorMetadata)
        .filter(VectorMetadata.document_id == document_id)
        .count()
        == 1
    )

    # Delete the document.
    db_session.delete(new_document)
    db_session.commit()

    # Check that the document is gone and the vector metadata has been cascaded.
    assert db_session.query(Document).filter(Document.id == document_id).count() == 0
    assert (
        db_session.query(VectorMetadata)
        .filter(VectorMetadata.document_id == document_id)
        .count()
        == 0
    )