# (repository-viewer navigation residue: "Newer" / "Older")
# cortex-hub / ai-hub / tests / db / test_session.py
# tests/db/test_session.py
import pytest
import importlib
from unittest.mock import patch

# --- Test Suite for app.db.session ---

def test_sqlite_mode_initialization(monkeypatch):
    """
    Tests if the session module correctly configures the SQLAlchemy engine for SQLite.

    The DB_MODE override is confined to a ``monkeypatch.context()`` scope, and
    ``app.config`` / ``app.db.session`` are reloaded once more after the
    override has been undone, so the module state built for this test does
    not leak into tests that run afterwards.
    """
    with monkeypatch.context() as scoped_env:
        # Arrange
        scoped_env.setenv("DB_MODE", "sqlite")

        # Imported only after the override is in place, so that a first-time
        # import of these modules already sees DB_MODE=sqlite.
        from app import config
        from app.db import session

        try:
            # Act: re-execute both modules so they read the patched environment.
            importlib.reload(config)
            importlib.reload(session)

            # Assert
            assert session.engine.dialect.name == "sqlite"
            assert "ai_hub.db" in session.engine.url.database
        finally:
            # Undoing the env patch alone does not re-run module-level code,
            # so drop the override first and then rebuild both modules from
            # the surrounding environment.
            scoped_env.undo()
            importlib.reload(config)
            importlib.reload(session)

def test_postgres_mode_initialization(monkeypatch):
    """
    Tests if the session module correctly configures the SQLAlchemy engine for PostgreSQL.

    The DB_MODE / DATABASE_URL overrides are confined to a
    ``monkeypatch.context()`` scope, and ``app.config`` / ``app.db.session``
    are reloaded once more after the overrides have been undone. Without that
    second reload, ``session.engine`` would stay bound to the fake test URL
    for every test that runs afterwards.
    """
    with monkeypatch.context() as scoped_env:
        # Arrange
        scoped_env.setenv("DB_MODE", "postgres")
        custom_url = "postgresql://test_user:test_password@testhost/test_db"
        scoped_env.setenv("DATABASE_URL", custom_url)

        # Imported only after the overrides are in place, so that a first-time
        # import of these modules already sees the patched environment.
        from app import config
        from app.db import session

        try:
            # Act: re-execute both modules so they read the patched environment.
            importlib.reload(config)
            importlib.reload(session)

            # Assert: each component of custom_url must be reflected on the engine URL.
            assert session.engine.url.drivername == "postgresql"
            assert session.engine.url.username == "test_user"
            assert session.engine.url.host == "testhost"
            assert session.engine.url.database == "test_db"
        finally:
            # Undoing the env patch alone does not re-run module-level code,
            # so drop the overrides first and then rebuild both modules from
            # the surrounding environment.
            scoped_env.undo()
            importlib.reload(config)
            importlib.reload(session)

# SessionLocal is patched in app.api.dependencies, the module get_db is
# imported from, i.e. where the name is used rather than where it is defined.
@patch("app.api.dependencies.SessionLocal")
def test_get_db_yields_and_closes_session(mock_session_local):
    """
    Verifies the get_db() dependency: it must hand out the session produced by
    SessionLocal() and close that session only once the generator is exhausted.
    """
    from app.api.dependencies import get_db

    # Arrange: the object SessionLocal() returns is the session we expect back.
    expected_session = mock_session_local.return_value
    dependency = get_db()

    # Act + Assert (yield): the first step hands out the session, still open.
    yielded_session = next(dependency)
    assert yielded_session is expected_session
    expected_session.close.assert_not_called()

    # Act (exhaust): stepping again must finish the generator.
    with pytest.raises(StopIteration):
        next(dependency)

    # Assert (cleanup): the session was closed exactly once.
    expected_session.close.assert_called_once()