# tests/db/test_session.py
import pytest
import importlib
from unittest.mock import patch
# --- Test Suite for app.db.session ---
def test_sqlite_mode_initialization(monkeypatch):
    """
    Tests if the session module correctly configures the SQLAlchemy engine for SQLite.

    ``app.config`` and ``app.db.session`` are reloaded under a patched environment.
    Afterwards the environment is restored and both modules are reloaded again, so
    later tests do not inherit the engine built here.
    """
    # Arrange
    monkeypatch.setenv("DB_MODE", "sqlite")
    # A DATABASE_URL inherited from the host environment might influence how config
    # builds the URL. Remove it so the test is hermetic.
    # NOTE(review): confirm against app.config whether sqlite mode reads it.
    monkeypatch.delenv("DATABASE_URL", raising=False)

    from app import config
    from app.db import session

    try:
        # Act: re-execute module-level configuration under the patched environment.
        importlib.reload(config)
        importlib.reload(session)

        # Assert
        assert session.engine.dialect.name == "sqlite"
        # Guard first so a missing path fails with a clear message, not a TypeError.
        assert session.engine.url.database is not None
        assert "ai_hub.db" in session.engine.url.database
    finally:
        # Restore the real environment BEFORE reloading. Otherwise the modules would
        # keep the engine built for this test for the rest of the session.
        # monkeypatch's own teardown will call undo() again, which is a no-op.
        monkeypatch.undo()
        importlib.reload(config)
        importlib.reload(session)
def test_postgres_mode_initialization(monkeypatch):
    """
    Tests if the session module correctly configures the SQLAlchemy engine for PostgreSQL.

    ``app.config`` and ``app.db.session`` are reloaded with DB_MODE=postgres and a
    custom DATABASE_URL. Afterwards the environment is restored and both modules are
    reloaded again, so later tests do not inherit an engine pointing at the fake
    ``testhost`` database.
    """
    # Arrange
    monkeypatch.setenv("DB_MODE", "postgres")
    custom_url = "postgresql://test_user:test_password@testhost/test_db"
    monkeypatch.setenv("DATABASE_URL", custom_url)

    from app import config
    from app.db import session

    try:
        # Act: re-execute module-level configuration under the patched environment.
        importlib.reload(config)
        importlib.reload(session)

        # Assert: every component of the custom URL made it into the engine.
        url = session.engine.url
        assert url.drivername == "postgresql"
        assert url.username == "test_user"
        assert url.host == "testhost"
        assert url.database == "test_db"
    finally:
        # Restore the real environment BEFORE reloading. Otherwise
        # app.db.session.engine would keep targeting testhost/test_db for every later
        # test. monkeypatch's own teardown will call undo() again, which is a no-op.
        monkeypatch.undo()
        importlib.reload(config)
        importlib.reload(session)
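# Patch SessionLocal in the namespace where get_db() looks it up (app.api.dependencies).
# Patching the module that defines it would not replace the name already bound there.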
@patch('app.api.dependencies.SessionLocal')
def test_get_db_yields_and_closes_session(mock_session_local):
    """
    Verify that get_db() hands out the session produced by SessionLocal and closes it
    only after the generator is resumed past its yield.

    Because SessionLocal is patched, the session that get_db() should hand out is
    simply the mock factory's return value.
    """
    from app.api.dependencies import get_db

    expected_session = mock_session_local.return_value
    gen = get_db()

    # First resume: the mocked session comes back and is still open.
    yielded = next(gen)
    assert yielded is expected_session
    assert expected_session.close.call_count == 0

    # Second resume: the generator must finish without producing another value...
    exhausted = object()
    assert next(gen, exhausted) is exhausted
    # ...and must have closed the session exactly once on the way out.
    assert expected_session.close.call_count == 1