import importlib
import os
from unittest.mock import patch

import pytest
from sqlalchemy import text
from sqlalchemy.exc import ResourceClosedError
from sqlalchemy.orm import Session

# NOTE(review): the configuration tests below re-execute ``app.db.database``
# via ``importlib.reload``. ``monkeypatch`` restores the environment after each
# test, but the module's globals (DB_MODE, DATABASE_URL, engine_args, ...) keep
# whatever the last reload produced. Tests elsewhere that rely on this module's
# import-time state may therefore be affected by test ordering.


def test_sqlite_mode_initialization(monkeypatch):
    """
    Tests if the database initializes in SQLite mode correctly.

    If an ``ai_hub.db`` file appears in the working directory during the test,
    it is removed afterwards, even when an assertion fails. A file that already
    existed before the test started is never deleted.
    """
    # Arrange: Set environment variable for SQLite mode, and drop any
    # DATABASE_URL inherited from the shell/CI so only DB_MODE drives the result.
    monkeypatch.setenv("DB_MODE", "sqlite")
    monkeypatch.delenv("DATABASE_URL", raising=False)

    db_file = "ai_hub.db"
    # Remember whether the file pre-existed so cleanup cannot destroy a real
    # local database that this test did not create.
    existed_before = os.path.exists(db_file)

    try:
        # Act: Reload the module to apply the monkeypatched env vars
        from app.db import database
        importlib.reload(database)

        # Assert: Check if the configuration is correct for SQLite
        assert database.DB_MODE == "sqlite"
        assert "sqlite:///./ai_hub.db" in database.DATABASE_URL
        assert "connect_args" in database.engine_args
        assert database.engine_args["connect_args"] == {"check_same_thread": False}
    finally:
        # Cleanup runs even on assertion failure, and only removes a file this
        # test run created.
        if not existed_before and os.path.exists(db_file):
            os.remove(db_file)


def test_postgres_mode_initialization(monkeypatch):
    """
    Tests if the database initializes in PostgreSQL mode with a custom URL.
    """
    # Arrange: Set env vars for PostgreSQL mode and a specific URL
    monkeypatch.setenv("DB_MODE", "postgres")
    monkeypatch.setenv("DATABASE_URL", "postgresql://test_user:test_password@testhost/test_db")

    # Act: Reload the module to apply the monkeypatched env vars
    from app.db import database
    importlib.reload(database)

    # Assert: Check if the configuration is correct for PostgreSQL
    assert database.DB_MODE == "postgres"
    assert database.DATABASE_URL == "postgresql://test_user:test_password@testhost/test_db"
    assert "pool_pre_ping" in database.engine_args


def test_default_to_postgres_mode(monkeypatch):
    """
    Tests if the system defaults to PostgreSQL mode when DB_MODE is not set.
    """
    # Arrange: Ensure DB_MODE is not set. DATABASE_URL is cleared too: in
    # postgres mode it overrides the default URL (see the test above), so an
    # inherited value (e.g. from CI) would otherwise make this test fail.
    monkeypatch.delenv("DB_MODE", raising=False)
    monkeypatch.delenv("DATABASE_URL", raising=False)

    # Act: Reload the module to apply the monkeypatched env vars
    from app.db import database
    importlib.reload(database)

    # Assert: Check that it defaulted to postgres
    assert database.DB_MODE == "postgres"
    assert "postgresql://user:password@localhost/ai_hub_db" in database.DATABASE_URL


@patch('app.db.database.SessionLocal')
def test_get_db_yields_and_closes_session(mock_session_local):
    """
    Tests if the get_db() dependency function yields a valid, active session
    and correctly closes it afterward by mocking the session object.
    """
    # Arrange: Get the actual get_db function from the module
    from app.db import database

    # Configure the mock session returned by SessionLocal()
    mock_session = mock_session_local.return_value
    db_generator = database.get_db()

    # Act
    # 1. Get the session object from the generator
    db_session_instance = next(db_generator)

    # Assert
    # 2. Check that the yielded object is our mock session, created exactly once
    assert db_session_instance is mock_session
    mock_session_local.assert_called_once()
    mock_session.close.assert_not_called()  # The session should not be closed yet

    # 3. Exhaust the generator to trigger the 'finally' block
    with pytest.raises(StopIteration):
        next(db_generator)

    # 4. Assert that the close() method was called exactly once.
    mock_session.close.assert_called_once()