diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py
index 72e11ce..9464fa7 100644
--- a/ai-hub/app/app.py
+++ b/ai-hub/app/app.py
@@ -13,7 +13,7 @@
 from app.db.session import create_db_and_tables
 from app.api.routes.api import create_api_router
 from app.utils import print_config
-from app.api.dependencies import ServiceContainer
+from app.api.dependencies import ServiceContainer, get_db
 from app.core.services.session import SessionService
 from app.core.services.tts import TTSService
 from app.core.services.stt import STTService  # NEW: Added the missing import for STTService
@@ -38,7 +38,7 @@
     print("Application shutdown...")
     # Access the vector_store from the application state to save it
     if hasattr(app.state, 'vector_store'):
-        app.state.vector_store.save_index_and_metadata()
+        app.state.vector_store.save_index()
 
 def create_app() -> FastAPI:
     """
@@ -68,8 +68,7 @@
     vector_store = FaissVectorStore(
         index_file_path=settings.FAISS_INDEX_PATH,
         dimension=settings.EMBEDDING_DIMENSION,
-        embedder=embedder # Pass the instantiated embedder object,
-
+        embedder=embedder
     )
 
     # CRITICAL FIX: Assign the vector_store to the app state so it can be saved on shutdown.
diff --git a/ai-hub/app/core/retrievers/faiss_db_retriever.py b/ai-hub/app/core/retrievers/faiss_db_retriever.py
index 3c3c202..216bfc4 100644
--- a/ai-hub/app/core/retrievers/faiss_db_retriever.py
+++ b/ai-hub/app/core/retrievers/faiss_db_retriever.py
@@ -17,22 +17,28 @@
         Retrieves document text by first searching the FAISS index and then
         fetching the corresponding documents from the database.
         """
-        faiss_ids = self.vector_store.search_similar_documents(query, k=3)
+        # Pass the db session explicitly to the vector store method
+        faiss_ids = self.vector_store.search_similar_documents(db_session=db, query_text=query, k=3)
         context_docs_text = []
         if faiss_ids:
-            # Use FAISS IDs to find the corresponding document_id from the database
+            # Get document_ids using the FAISS IDs
             document_ids_from_vectors = db.query(models.VectorMetadata.document_id).filter(
-                models.VectorMetadata.faiss_index.in_(faiss_ids)
+                models.VectorMetadata.document_id.in_(faiss_ids)
             ).all()
             document_ids = [doc_id for (doc_id,) in document_ids_from_vectors]
-
-            # Retrieve the full documents from the Document table
+            # Retrieve full documents by their IDs
             context_docs = db.query(models.Document).filter(
                 models.Document.id.in_(document_ids)
             ).all()
-
             context_docs_text = [doc.text for doc in context_docs]
-        return context_docs_text
\ No newline at end of file
+        return context_docs_text
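The retriever now requires the caller to hand it a live SQLAlchemy session, since tag metadata lives in the database rather than in pickled sidecar files. A minimal sketch of the calling pattern; `session_factory` and `fetch_context` are illustrative stand-ins, not part of this change:

# Sketch: driving the new retriever contract end to end. `session_factory`
# is a placeholder for however the app produces SQLAlchemy sessions.
from app.core.retrievers.faiss_db_retriever import FaissDBRetriever

def fetch_context(vector_store, session_factory, query: str) -> list:
    db = session_factory()
    try:
        retriever = FaissDBRetriever(vector_store=vector_store)
        # The store reads metadata through this live session instead of
        # from pickled .tags/.vecs files.
        return retriever.retrieve_context(query=query, db=db)
    finally:
        db.close()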
""" try: document_db = models.Document(**doc_data) db.add(document_db) db.commit() db.refresh(document_db) - + + # Determine embedding model name if needed embedding_model_name = "mock_embedder" if isinstance(self.embedder, MockEmbedder) else "GenAIEmbedder" - - faiss_index = self.vector_store.add_document(document_db.text) - vector_metadata = models.VectorMetadata( + + # Pass db_session explicitly as positional argument, and document_id + faiss_index_id = self.vector_store.add_document( + text= document_db.text, document_id=document_db.id, - faiss_index=faiss_index, + db_session=db, embedding_model=embedding_model_name ) - db.add(vector_metadata) - db.commit() + return document_db.id - except SQLAlchemyError as e: + + except SQLAlchemyError: db.rollback() raise diff --git a/ai-hub/app/core/vector_store/faiss_store.py b/ai-hub/app/core/vector_store/faiss_store.py index 5846e7c..dc50a96 100644 --- a/ai-hub/app/core/vector_store/faiss_store.py +++ b/ai-hub/app/core/vector_store/faiss_store.py @@ -2,77 +2,120 @@ import logging import faiss import numpy as np -import pickle from typing import List, Optional, Dict, Any - +from sqlalchemy.orm import Session +from sqlalchemy import select from .base import VectorStore -from .utils import save_faiss_index, load_faiss_index +from app.db.models import VectorMetadata, Document # Import your SQLAlchemy models class FaissVectorStore(VectorStore): """ - An in-memory vector store using the FAISS library with support for filtering - by metadata tags and persistence of both the index and the tags. + A FAISS vector store that uses a relational database (via SQLAlchemy) + for persistence and filtering of document metadata. """ + def __init__(self, index_file_path: str, dimension: int, embedder): self.index_file_path = index_file_path self.dimension = dimension self.embedder = embedder + self.index = None + self.doc_id_map = [] - self.doc_tags = {} # Metadata per document - self.doc_vectors = {} # Store vectors for filtered search - + def initialize_index(self, db_session: Session): + """Initializes the FAISS index and syncs it with the database.""" if os.path.exists(self.index_file_path): logging.info(f"Loading FAISS index from {self.index_file_path}") self.index = faiss.read_index(self.index_file_path) - self.load_metadata() - self.load_vectors() - self.doc_id_map = list(self.doc_tags.keys()) else: logging.info("Creating a new FAISS index.") - quantizer = faiss.IndexFlatL2(dimension) + quantizer = faiss.IndexFlatL2(self.dimension) self.index = faiss.IndexIDMap(quantizer) - self.doc_id_map = [] - def add_document(self, text: str, tags: Optional[Dict[str, Any]] = None) -> int: + self.sync_with_db(db_session) + + def sync_with_db(self, db_session: Session): + """Syncs the in-memory FAISS ID map with the database metadata.""" + logging.info("Synchronizing FAISS index with database metadata.") + faiss_ids_from_db = db_session.execute(select(VectorMetadata.id)).scalars().all() + self.doc_id_map = faiss_ids_from_db + + def add_document(self, text: str, document_id: int, + session_id: Optional[int] = None, + embedding_model: str = "default_model", + db_session: Session = None) -> int: + """Embeds and adds a single document to FAISS and DB.""" + if db_session is None: + raise ValueError("db_session must be provided") + logging.debug("Embedding document text for FAISS index...") vector = self.embedder.embed_text(text).reshape(1, -1).astype('float32') - new_doc_id = self.index.ntotal - self.index.add_with_ids(vector, np.array([new_doc_id], dtype='int64')) + 
diff --git a/ai-hub/app/core/vector_store/faiss_store.py b/ai-hub/app/core/vector_store/faiss_store.py
index 5846e7c..dc50a96 100644
--- a/ai-hub/app/core/vector_store/faiss_store.py
+++ b/ai-hub/app/core/vector_store/faiss_store.py
@@ -2,77 +2,120 @@
 import logging
 import faiss
 import numpy as np
-import pickle
 from typing import List, Optional, Dict, Any
-
+from sqlalchemy.orm import Session
+from sqlalchemy import select
 from .base import VectorStore
-from .utils import save_faiss_index, load_faiss_index
+from app.db.models import VectorMetadata, Document  # Import the SQLAlchemy models
 
 class FaissVectorStore(VectorStore):
     """
-    An in-memory vector store using the FAISS library with support for filtering
-    by metadata tags and persistence of both the index and the tags.
+    A FAISS vector store that uses a relational database (via SQLAlchemy)
+    for persistence and filtering of document metadata.
     """
+
     def __init__(self, index_file_path: str, dimension: int, embedder):
         self.index_file_path = index_file_path
         self.dimension = dimension
         self.embedder = embedder
+        self.index = None
+        self.doc_id_map = []
-        self.doc_tags = {}     # Metadata per document
-        self.doc_vectors = {}  # Store vectors for filtered search
-
+    def initialize_index(self, db_session: Session):
+        """Initializes the FAISS index and syncs it with the database."""
         if os.path.exists(self.index_file_path):
             logging.info(f"Loading FAISS index from {self.index_file_path}")
             self.index = faiss.read_index(self.index_file_path)
-            self.load_metadata()
-            self.load_vectors()
-            self.doc_id_map = list(self.doc_tags.keys())
         else:
             logging.info("Creating a new FAISS index.")
-            quantizer = faiss.IndexFlatL2(dimension)
+            quantizer = faiss.IndexFlatL2(self.dimension)
             self.index = faiss.IndexIDMap(quantizer)
-            self.doc_id_map = []
 
-    def add_document(self, text: str, tags: Optional[Dict[str, Any]] = None) -> int:
+        self.sync_with_db(db_session)
+
+    def sync_with_db(self, db_session: Session):
+        """Syncs the in-memory FAISS ID map with the database metadata."""
+        logging.info("Synchronizing FAISS index with database metadata.")
+        faiss_ids_from_db = db_session.execute(select(VectorMetadata.id)).scalars().all()
+        self.doc_id_map = faiss_ids_from_db
+
+    def add_document(self, text: str, document_id: int,
+                     session_id: Optional[int] = None,
+                     embedding_model: str = "default_model",
+                     db_session: Session = None) -> int:
+        """Embeds and adds a single document to FAISS and the DB."""
+        if db_session is None:
+            raise ValueError("db_session must be provided")
+
         logging.debug("Embedding document text for FAISS index...")
         vector = self.embedder.embed_text(text).reshape(1, -1).astype('float32')
-        new_doc_id = self.index.ntotal
-        self.index.add_with_ids(vector, np.array([new_doc_id], dtype='int64'))
+
+        metadata_entry = VectorMetadata(
+            document_id=document_id,
+            session_id=session_id,
+            embedding_model=embedding_model,
+        )
+        db_session.add(metadata_entry)
+        db_session.flush()
 
-        self.doc_id_map.append(new_doc_id)
-        self.doc_tags[new_doc_id] = tags if tags else {}
-        self.doc_vectors[new_doc_id] = vector.flatten()
+        faiss_index_id = metadata_entry.id
+        if self.index is None:
+            self.initialize_index(db_session)
+        self.index.add_with_ids(vector, np.array([faiss_index_id], dtype='int64'))
+        self.doc_id_map.append(faiss_index_id)
 
-        self.save_index_and_metadata()
-        logging.info(f"Document added to FAISS index with ID: {new_doc_id}")
+        db_session.commit()
+        self.save_index()
 
-        return new_doc_id
+        logging.info(f"Document added to FAISS index with ID: {faiss_index_id}")
+        return faiss_index_id
 
-    def add_multiple_documents(self, texts: List[str], tags: Optional[List[Dict[str, Any]]] = None) -> List[int]:
+    def add_multiple_documents(self, texts: List[str], document_ids: List[int],
+                               session_id: Optional[int] = None,
+                               embedding_model: str = "default_model",
+                               db_session: Session = None) -> List[int]:
+        """Adds multiple documents to FAISS and the DB."""
+        if db_session is None:
+            raise ValueError("db_session must be provided")
+
+        if len(texts) != len(document_ids):
+            raise ValueError("The number of texts must match the number of document IDs.")
+
         logging.debug("Embedding multiple document texts for FAISS index...")
-        vectors = np.vstack([
-            self.embedder.embed_text(text).reshape(1, -1) for text in texts
-        ]).astype('float32')
+        vectors = self.embedder.embed_multiple_texts(texts).astype('float32')
 
-        start_id = self.index.ntotal
-        new_doc_ids = list(range(start_id, start_id + len(texts)))
+        new_metadata_entries = []
+        for doc_id in document_ids:
+            new_metadata_entries.append(VectorMetadata(
+                document_id=doc_id,
+                session_id=session_id,
+                embedding_model=embedding_model,
+            ))
 
-        self.index.add_with_ids(vectors, np.array(new_doc_ids, dtype='int64'))
+        db_session.add_all(new_metadata_entries)
+        db_session.flush()
 
-        self.doc_id_map.extend(new_doc_ids)
-        for i, doc_id in enumerate(new_doc_ids):
-            self.doc_tags[doc_id] = tags[i] if tags and len(tags) == len(texts) else {}
-            self.doc_vectors[doc_id] = vectors[i]
+        new_faiss_ids = [entry.id for entry in new_metadata_entries]
+        if self.index is None:
+            self.initialize_index(db_session)
+        self.index.add_with_ids(vectors, np.array(new_faiss_ids, dtype='int64'))
+        self.doc_id_map.extend(new_faiss_ids)
 
-        self.save_index_and_metadata()
-        logging.info(f"Added {len(new_doc_ids)} documents to FAISS index.")
+        db_session.commit()
+        self.save_index()
 
-        return new_doc_ids
+        logging.info(f"Added {len(new_faiss_ids)} documents to FAISS index.")
+        return new_faiss_ids
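The core pattern in both add methods: flush the metadata row so the database assigns its primary key, then reuse that key as the FAISS vector ID. Distilled into a standalone sketch (names are illustrative):

# Sketch of the PK-as-vector-ID pattern used above.
import faiss
import numpy as np
from app.db.models import VectorMetadata

def add_vector(db_session, index, vector, document_id: int) -> int:
    entry = VectorMetadata(document_id=document_id, embedding_model="demo")
    db_session.add(entry)
    db_session.flush()  # assigns entry.id without ending the transaction
    index.add_with_ids(vector.reshape(1, -1).astype("float32"),
                       np.array([entry.id], dtype="int64"))
    db_session.commit()  # the FAISS ID and the PK now agree by construction
    return entry.id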
-    def search_similar_documents(self, query_text: str, k: int = 5, prefilter_tags: Optional[Dict[str, Any]] = None) -> List[int]:
-        logging.debug(f"Searching FAISS index for similar documents to query: '{query_text[:50]}...'")
+    def search_similar_documents(self, query_text: str, k: int = 5,
+                                 prefilter_tags: Optional[Dict[str, Any]] = None,
+                                 db_session: Session = None) -> List[int]:
+        """Searches FAISS for similar vectors, with an optional DB-backed
+        prefilter, and returns the matching document IDs."""
+        if db_session is None:
+            raise ValueError("db_session must be provided")
+
+        logging.debug("Searching FAISS index with database-backed filtering.")
         if self.index.ntotal == 0:
             logging.warning("FAISS index is empty, no documents to search.")
             return []
@@ -80,78 +123,37 @@
         query_vector = self.embedder.embed_text(query_text).reshape(1, -1).astype('float32')
 
         if prefilter_tags:
-            valid_ids = [
-                doc_id for doc_id, tags in self.doc_tags.items()
-                if all(tags.get(key) == value for key, value in prefilter_tags.items())
-            ]
+            db_query = select(VectorMetadata.id)
+            for key, value in prefilter_tags.items():
+                if hasattr(VectorMetadata, key):
+                    db_query = db_query.where(getattr(VectorMetadata, key) == value)
+                else:
+                    logging.warning(f"Metadata key '{key}' not found in VectorMetadata model.")
+                    return []
 
-            if not valid_ids:
+            valid_faiss_ids = db_session.execute(db_query).scalars().all()
+
+            if not valid_faiss_ids:
                 logging.warning("No documents match the filter criteria.")
                 return []
 
-            try:
-                filtered_vectors = np.vstack([
-                    self.doc_vectors[doc_id].reshape(1, -1)
-                    for doc_id in valid_ids
-                ]).astype('float32')
-
-                temp_index = faiss.IndexFlatL2(self.dimension)
-                temp_index.add(filtered_vectors)
-
-                D, I = temp_index.search(query_vector, min(k, len(valid_ids)))
-                result_ids = [int(valid_ids[i]) for i in I.flatten() if i >= 0]
-
-            except Exception as e:
-                logging.error(f"Error during filtered search: {e}")
-                return []
+            id_selector = faiss.IDSelectorBatch(valid_faiss_ids)
+            D, I = self.index.search(query_vector, min(k, len(valid_faiss_ids)),
+                                     params=faiss.SearchParameters(sel=id_selector))
+            result_faiss_ids = [int(i) for i in I.flatten() if i >= 0]
         else:
             D, I = self.index.search(query_vector, k)
-            result_ids = [int(i) for i in I.flatten() if i >= 0]
+            result_faiss_ids = [int(i) for i in I.flatten() if i >= 0]
 
-        logging.info(f"Search complete, found {len(result_ids)} similar documents.")
-        return result_ids
+        result_document_ids = db_session.execute(
+            select(VectorMetadata.document_id).where(VectorMetadata.id.in_(result_faiss_ids))
+        ).scalars().all()
 
-    def save_index_and_metadata(self):
+        logging.info(f"Search complete, found {len(result_document_ids)} similar documents.")
+        return result_document_ids
+
+    def save_index(self):
+        """Saves the FAISS index to the file system."""
         if self.index:
             logging.info(f"Saving FAISS index to {self.index_file_path}")
             faiss.write_index(self.index, self.index_file_path)
-
-        # Save metadata
-        tags_file_path = self.index_file_path + ".tags"
-        with open(tags_file_path, 'wb') as f:
-            pickle.dump(self.doc_tags, f)
-        logging.info(f"Saved metadata to {tags_file_path}")
-
-        # Save vectors
-        vectors_file_path = self.index_file_path + ".vecs"
-        with open(vectors_file_path, 'wb') as f:
-            pickle.dump(self.doc_vectors, f)
-        logging.info(f"Saved document vectors to {vectors_file_path}")
-
-    def load_metadata(self):
-        tags_file_path = self.index_file_path + ".tags"
-        if os.path.exists(tags_file_path):
-            try:
-                with open(tags_file_path, 'rb') as f:
-                    self.doc_tags = pickle.load(f)
-                logging.info(f"Loaded metadata from {tags_file_path}")
-            except Exception as e:
-                logging.error(f"Failed to load metadata file: {e}")
-                self.doc_tags = {}
-        else:
-            logging.warning("Metadata file not found, initializing empty tags dictionary.")
-            self.doc_tags = {}
-
-    def load_vectors(self):
-        vectors_file_path = self.index_file_path + ".vecs"
-        if os.path.exists(vectors_file_path):
-            try:
-                with open(vectors_file_path, 'rb') as f:
-                    self.doc_vectors = pickle.load(f)
-                logging.info(f"Loaded vectors from {vectors_file_path}")
-            except Exception as e:
-                logging.error(f"Failed to load vectors file: {e}")
-                self.doc_vectors = {}
-        else:
-            logging.warning("Vectors file not found, initializing empty vector dictionary.")
-            self.doc_vectors = {}
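Worth flagging: selector-aware `search(..., params=...)` needs a reasonably recent FAISS build (selector support in search parameters landed around v1.7.3); on older builds it raises. A defensive sketch of the same idea, with version support treated as an assumption the caller should verify:

# Sketch: restrict a FAISS search to IDs taken from the metadata table.
# Assumes `index` is the IndexIDMap built above.
import faiss
import numpy as np

def filtered_search(index, query_vector, allowed_ids, k: int):
    sel = faiss.IDSelectorBatch(np.asarray(allowed_ids, dtype="int64"))
    distances, ids = index.search(
        query_vector, min(k, len(allowed_ids)),
        params=faiss.SearchParameters(sel=sel),
    )
    # FAISS pads missing results with -1; drop them.
    return [int(i) for i in ids.ravel() if i >= 0]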
diff --git a/ai-hub/app/db/models.py b/ai-hub/app/db/models.py
index 2d3476c..2cf0ebd 100644
--- a/ai-hub/app/db/models.py
+++ b/ai-hub/app/db/models.py
@@ -130,17 +130,15 @@
     SQLAlchemy model for the 'vector_metadata' table.
 
     This table links a document to its corresponding vector representation
-    in the FAISS index. It is critical for syncing data between the
-    relational database and the vector store.
+    in the FAISS index. The primary key `id` of this table serves as the
+    vector ID in the FAISS store, making the `faiss_index` column redundant.
     """
     __tablename__ = 'vector_metadata'
-    # Primary key for the metadata entry.
+    # Primary key for the metadata entry. This will also be the FAISS index ID.
     id = Column(Integer, primary_key=True, index=True)
     # Foreign key that links this metadata entry back to its Document.
     document_id = Column(Integer, ForeignKey('documents.id'), unique=True)
-    # The index number in the FAISS vector store where the vector for this document is stored.
-    faiss_index = Column(Integer, nullable=False, index=True)
     # Foreign key to link this vector metadata to a specific session.
     # This is crucial for retrieving relevant RAG context for a given conversation.
     session_id = Column(Integer, ForeignKey('sessions.id'), nullable=True)
@@ -156,4 +154,4 @@
         """
         Provides a helpful string representation of the object for debugging.
         """
-        return f"<VectorMetadata(id={self.id}, document_id={self.document_id}, faiss_index={self.faiss_index})>"
+        return f"<VectorMetadata(id={self.id}, document_id={self.document_id})>"
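With `faiss_index` gone, mapping FAISS hits back to documents is a plain primary-key lookup, which is what the search method above already does. In isolation:

# Sketch: resolve FAISS hit IDs (VectorMetadata primary keys) to documents.
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.db.models import VectorMetadata, Document

def documents_for_hits(db: Session, faiss_ids: list) -> list:
    doc_ids = db.execute(
        select(VectorMetadata.document_id).where(VectorMetadata.id.in_(faiss_ids))
    ).scalars().all()
    return db.execute(
        select(Document).where(Document.id.in_(doc_ids))
    ).scalars().all()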
diff --git a/ai-hub/tests/core/retrievers/test_faiss_db_retriever.py b/ai-hub/tests/core/retrievers/test_faiss_db_retriever.py
index 00a3735..e070219 100644
--- a/ai-hub/tests/core/retrievers/test_faiss_db_retriever.py
+++ b/ai-hub/tests/core/retrievers/test_faiss_db_retriever.py
@@ -1,55 +1,37 @@
 import pytest
-from sqlalchemy import create_engine, Column, Integer, String, ForeignKey
-from sqlalchemy.orm import sessionmaker, declarative_base, relationship
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker
 from typing import List
+from datetime import datetime
 
-# Mock the required models and FaissVectorStore for testing purposes
-Base = declarative_base()
+# Import the actual models and base
+from app.db.models import Base, Document, VectorMetadata
+from app.core.retrievers.faiss_db_retriever import FaissDBRetriever
 
-class Document(Base):
-    __tablename__ = "documents"
-    id = Column(Integer, primary_key=True, index=True)
-    title = Column(String)  # Add the missing columns
-    text = Column(String)
-    source_url = Column(String)
-    author = Column(String)
-    status = Column(String)
-    created_at = Column(String)
-    user_id = Column(Integer)
-    vectors = relationship("VectorMetadata", back_populates="document")
 
-class VectorMetadata(Base):
-    __tablename__ = "vector_metadata"
-    id = Column(Integer, primary_key=True, index=True)
-    faiss_index = Column(Integer, unique=True)
-    document_id = Column(Integer, ForeignKey("documents.id"))
-    document = relationship("Document", back_populates="vectors")
-
-# A mock version of the FaissVectorStore for testing
-class FaissVectorStore:
-    def __init__(self, encoder):
+# --- Mock FaissVectorStore ---
+class MockFaissVectorStore:
+    def __init__(self):
         self._index = {}
-        self._next_id = 0
 
     def add_document_to_index(self, text: str) -> int:
-        self._index[self._next_id] = text
-        current_id = self._next_id
-        self._next_id += 1
-        return current_id
+        new_id = len(self._index)
+        self._index[new_id] = text
+        return new_id
 
-    def search_similar_documents(self, query: str, k: int = 1) -> List[int]:
-        # This is a mock; in a real scenario, this would perform a vector search.
-        # Here, we'll assume it returns the IDs we expect for the test.
-        # This method is often patched in the tests.
+    def search_similar_documents(self, query_text: str, k: int = 1, db_session=None, prefilter_tags=None) -> List[int]:
+        # Always return vector ID 0 for test simplicity
         return [0]
 
-# --- E2E test setup and fixtures ---
+
+# --- Database setup ---
 SQLALCHEMY_DATABASE_URL = "sqlite:///./data/test.db"
 engine = create_engine(
     SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
 )
 TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
 
+
 @pytest.fixture(scope="function")
 def db_session():
     """Provides a clean database session for each test."""
@@ -61,67 +43,54 @@
         db.close()
     Base.metadata.drop_all(bind=engine)
 
+
 @pytest.fixture(scope="function")
 def faiss_store():
-    """Provides a fresh FaissVectorStore instance for each test."""
-    class MockEncoder:
-        def encode(self, text):
-            return [1.0] * 768
-
-    return FaissVectorStore(MockEncoder())
+    """Provides a mock FaissVectorStore for each test."""
+    return MockFaissVectorStore()
 
-# --- E2E test cases ---
-# Assuming FaissDBRetriever and its dependencies are correctly imported
-# You need to make sure the import path is correct for your project structure
-from app.core.retrievers.faiss_db_retriever import FaissDBRetriever
 
+# --- Tests ---
 def test_retrieve_context_successful(db_session, faiss_store):
-    """
-    Tests that the retriever successfully finds and returns the correct
-    document text based on a FAISS search.
-    """
-    # 1. Setup - Create documents and vectors
-    doc1 = Document(
-        title="Sample Title",  # Provide values for the new columns
-        text="The quick brown fox jumps over the lazy dog."
+    # Step 1: Create a document whose ID the mocked search will return
+    doc = Document(
+        id=123,
+        title="Sample Title",
+        text="The quick brown fox jumps over the lazy dog.",
+        source_url="",
+        author="",
+        status="ready",
+        created_at=datetime(2023, 1, 1, 0, 0, 0),  # must be a datetime, not a string
+        user_id="test_user"
     )
-    db_session.add(doc1)
-    db_session.commit()
-    db_session.refresh(doc1)
-
-    # Mock the FAISS store to return a predictable ID
-    faiss_id1 = 123
-    faiss_store.add_document_to_index = lambda text: faiss_id1
-
-    vec_meta1 = VectorMetadata(document_id=doc1.id, faiss_index=faiss_id1)
-    db_session.add(vec_meta1)
+    db_session.add(doc)
     db_session.commit()
 
-    # 2. Execution - Create and run the retriever
+    # Step 2: Create VectorMetadata linking the document to vector ID 0
+    vector_metadata = VectorMetadata(
+        id=0,
+        document_id=123,
+        embedding_model="mock"
+    )
+    db_session.add(vector_metadata)
+    db_session.commit()
+
+    # Step 3: Run the retriever; the mocked search returns the document ID
+    retriever = FaissDBRetriever(vector_store=faiss_store)
-
-    # We'll mock the search to return the ID of our specific document
-    faiss_store.search_similar_documents = lambda query, k: [faiss_id1]
+    faiss_store.search_similar_documents = lambda query_text, k, db_session=None, prefilter_tags=None: [123]
+
     retrieved_context = retriever.retrieve_context(query="query for fox", db=db_session)
 
-    # 3. Assertion - Verify the result
+    # Step 4: Assertions
     assert len(retrieved_context) == 1
     assert retrieved_context[0] == "The quick brown fox jumps over the lazy dog."
 
+
 def test_retrieve_context_no_match(db_session, faiss_store):
-    """
-    Tests that the retriever returns an empty list when no matching
-    documents are found in the FAISS index.
-    """
-    # 1. Setup - No documents or vectors are added to the database.
-
-    # 2. Execution - Create and run the retriever
     retriever = FaissDBRetriever(vector_store=faiss_store)
-
-    # Mock the search to return an empty list
-    faiss_store.search_similar_documents = lambda query, k: []
-
-    retrieved_context = retriever.retrieve_context(query="non-existent query", db=db_session)
 
-    # 3. Assertion - Verify the result is an empty list
-    assert retrieved_context == []
\ No newline at end of file
+    # Override to return an empty list
+    faiss_store.search_similar_documents = lambda query_text, k, db_session=None, prefilter_tags=None: []
+
+    retrieved_context = retriever.retrieve_context(query="no match", db=db_session)
+    assert retrieved_context == []
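One caveat with this suite: it writes to `sqlite:///./data/test.db` on disk, so the `./data` directory must exist before the tests run. An in-memory engine avoids the filesystem dependency; a sketch of the alternative fixture (not part of this change):

# Sketch: in-memory variant of the fixture above. StaticPool keeps one
# connection alive so the :memory: database survives across sessions.
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from app.db.models import Base

@pytest.fixture(scope="function")
def memory_db_session():
    engine = create_engine(
        "sqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    Base.metadata.create_all(bind=engine)
    db = sessionmaker(bind=engine)()
    try:
        yield db
    finally:
        db.close()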
diff --git a/ai-hub/tests/core/services/test_document.py b/ai-hub/tests/core/services/test_document.py
index cdf8b44..850b0a9 100644
--- a/ai-hub/tests/core/services/test_document.py
+++ b/ai-hub/tests/core/services/test_document.py
@@ -58,13 +58,13 @@
     mock_db.add.assert_any_call(mock_document_model_instance)
     mock_db.commit.assert_called()
     mock_db.refresh.assert_called_with(mock_document_model_instance)
-    document_service.vector_store.add_document.assert_called_once_with("Test text.")
-    mock_vector_metadata_model.assert_called_once_with(
-        document_id=1,
-        faiss_index=123,
-        embedding_model="mock_embedder"
-    )
-    mock_db.add.assert_any_call(mock_vector_metadata_model.return_value)
+    document_service.vector_store.add_document.assert_called_once_with(
+        text="Test text.",
+        document_id=1,
+        db_session=mock_db,
+        embedding_model="mock_embedder"
+    )
 
 
 def test_add_document_sql_error(document_service: DocumentService):
diff --git a/ai-hub/tests/core/vector_store/test_faiss_store.py b/ai-hub/tests/core/vector_store/test_faiss_store.py
index ee43b22..316fd21 100644
--- a/ai-hub/tests/core/vector_store/test_faiss_store.py
+++ b/ai-hub/tests/core/vector_store/test_faiss_store.py
@@ -2,23 +2,48 @@
 import shutil
 import tempfile
 import pytest
-import pickle
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker
 import numpy as np
+from typing import List, Dict
 from app.core.vector_store.faiss_store import FaissVectorStore
-
+from app.db.models import Base, VectorMetadata, Document, Session
 
 # -----------------------------
 # Fixtures
 # -----------------------------
+@pytest.fixture(scope="function")
+def engine():
+    return create_engine("sqlite:///:memory:")
+
+@pytest.fixture(scope="function")
+def setup_database(engine):
+    Base.metadata.create_all(engine)
+    yield
+    Base.metadata.drop_all(engine)
+
+@pytest.fixture
+def db_session(setup_database, engine):
+    SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
+    session = SessionLocal()
+    session.begin()
+    try:
+        yield session
+    finally:
+        session.rollback()
+        session.close()
+
 @pytest.fixture
 def mock_embedder():
     class MockEmbedder:
         def embed_text(self, text):
-            # Return a deterministic fake vector based on hash of text
             np.random.seed(abs(hash(text)) % 2**32)
             return np.random.rand(768).astype('float32')
+
+        def embed_multiple_texts(self, texts):
+            return np.vstack([self.embed_text(text) for text in texts])
     return MockEmbedder()
 
 @pytest.fixture
@@ -29,91 +54,154 @@
     shutil.rmtree(tmp_dir)
 
 @pytest.fixture
-def faiss_store(temp_faiss_file, mock_embedder):
-    return FaissVectorStore(index_file_path=temp_faiss_file, dimension=768, embedder=mock_embedder)
+def faiss_store(temp_faiss_file, mock_embedder, db_session):
+    store = FaissVectorStore(
+        index_file_path=temp_faiss_file,
+        dimension=768,
+        embedder=mock_embedder
+    )
+    store.initialize_index(db_session)
+    return store
 
+# -----------------------------
+# Helper Function for Test Setup
+# -----------------------------
+def create_test_data(db_session, faiss_store, texts_and_metadata: List[Dict]):
+    dummy_session = Session(user_id="test_user", title="Test Session")
+    db_session.add(dummy_session)
+    db_session.commit()
+
+    document_ids = []
+    texts = []
+
+    for item in texts_and_metadata:
+        dummy_doc = Document(title=item.get("title", "Test Doc"), text=item["text"])
+        db_session.add(dummy_doc)
+        db_session.flush()
+        document_ids.append(dummy_doc.id)
+        texts.append(dummy_doc.text)
+
+    db_session.commit()
+
+    faiss_store.add_multiple_documents(
+        texts=texts,
+        document_ids=document_ids,
+        session_id=dummy_session.id,
+        embedding_model="test_model",
+        db_session=db_session
+    )
+
+    return document_ids, dummy_session.id
 # -----------------------------
 # Tests
 # -----------------------------
-def test_add_document(faiss_store):
+def test_add_document(faiss_store, db_session):
     test_text = "This is a test document."
-    test_tags = {"author": "John Doe", "year": 2023}
+    test_doc = Document(title="Test", text=test_text)
+    db_session.add(test_doc)
+    db_session.commit()
+    document_id = test_doc.id
 
     assert faiss_store.index.ntotal == 0
-    assert faiss_store.doc_tags == {}
+    assert db_session.query(VectorMetadata).count() == 0
 
-    doc_id = faiss_store.add_document(test_text, tags=test_tags)
+    faiss_index_id = faiss_store.add_document(test_text, document_id=document_id, db_session=db_session)
 
     assert faiss_store.index.ntotal == 1
-    assert doc_id == 0
     assert os.path.exists(faiss_store.index_file_path)
-    assert os.path.exists(faiss_store.index_file_path + ".tags")
-    assert os.path.exists(faiss_store.index_file_path + ".vecs")
-    assert faiss_store.doc_tags[doc_id] == test_tags
-    assert isinstance(faiss_store.doc_vectors[doc_id], np.ndarray)
+    metadata_entry = db_session.query(VectorMetadata).filter_by(id=faiss_index_id).first()
+    assert metadata_entry is not None
+    assert metadata_entry.document_id == document_id
+    assert metadata_entry.id == faiss_index_id
 
-def test_add_multiple_documents(faiss_store):
-    docs = ["Doc 1", "Doc 2", "Doc 3"]
-    tags = [{"type": "a"}, {"type": "b"}, {"type": "a"}]
+def test_add_multiple_documents(faiss_store, db_session):
+    docs = [
+        {"text": "Doc 1"},
+        {"text": "Doc 2"},
+        {"text": "Doc 3"}
+    ]
 
-    assert faiss_store.index.ntotal == 0
-
-    doc_ids = faiss_store.add_multiple_documents(docs, tags=tags)
+    doc_ids, _ = create_test_data(db_session, faiss_store, docs)
 
     assert faiss_store.index.ntotal == 3
-    assert doc_ids == [0, 1, 2]
-    assert faiss_store.doc_tags[0] == {"type": "a"}
-    assert faiss_store.doc_tags[1] == {"type": "b"}
-    assert faiss_store.doc_tags[2] == {"type": "a"}
-    assert all(isinstance(faiss_store.doc_vectors[i], np.ndarray) for i in doc_ids)
+    assert db_session.query(VectorMetadata).count() == 3
+    db_faiss_ids = db_session.query(VectorMetadata.id).order_by(VectorMetadata.id).all()
+    assert [row[0] for row in db_faiss_ids] == faiss_store.doc_id_map
 
-def test_load_existing_index_with_metadata(temp_faiss_file, mock_embedder):
+def test_load_existing_index_and_sync_with_db(temp_faiss_file, mock_embedder, db_session):
     store1 = FaissVectorStore(temp_faiss_file, 768, mock_embedder)
-    store1.add_document("Persistence test with tags.", tags={"status": "complete"})
+    store1.initialize_index(db_session)
+    create_test_data(db_session, store1, [{"text": "Persistence test 1"}, {"text": "Persistence test 2"}])
 
-    # Reload
     store2 = FaissVectorStore(temp_faiss_file, 768, mock_embedder)
+    store2.initialize_index(db_session)
 
-    assert store2.index.ntotal == 1
-    assert store2.doc_id_map == [0]
-    assert store2.doc_tags[0] == {"status": "complete"}
-    assert isinstance(store2.doc_vectors[0], np.ndarray)
+    assert store2.index.ntotal == 2
+    assert db_session.query(VectorMetadata).count() == 2
+    expected_ids = db_session.query(VectorMetadata.id).order_by(VectorMetadata.id).all()
+    expected_ids = [row[0] for row in expected_ids]
+    assert sorted(store2.doc_id_map) == sorted(expected_ids)
 
-def test_search_similar_documents_without_filter(faiss_store):
-    faiss_store.add_document("The sun is a star.", tags={"category": "astronomy"})
-    faiss_store.add_document("Mars is a planet.", tags={"category": "astronomy"})
-    faiss_store.add_document("The moon orbits the Earth.", tags={"category": "astronomy"})
+def test_search_similar_documents_without_filter(faiss_store, db_session):
+    docs = [
+        {"text": "The sun is a star.", "title": "astronomy"},
+        {"text": "Mars is a planet.", "title": "astronomy"},
+        {"text": "The moon orbits the Earth.", "title": "astronomy"}
+    ]
+    create_test_data(db_session, faiss_store, docs)
 
-    results = faiss_store.search_similar_documents("What is a star?", k=2)
+    results = faiss_store.search_similar_documents("What is a star?", k=2, db_session=db_session)
     assert len(results) == 2
-    assert all(isinstance(doc_id, int) for doc_id in results)
+    assert isinstance(results[0], int)
 
-def test_search_similar_documents_with_filter(faiss_store):
-    faiss_store.add_document("Python is a programming language.", tags={"type": "programming"})
-    faiss_store.add_document("A dog is a loyal pet.", tags={"type": "animal"})
-    faiss_store.add_document("Java is another programming language.", tags={"type": "programming"})
+def test_search_similar_documents_with_filter(faiss_store, db_session):
+    dummy_session1 = Session(user_id="user1", title="Session 1")
+    dummy_session2 = Session(user_id="user2", title="Session 2")
+    db_session.add_all([dummy_session1, dummy_session2])
+    db_session.commit()
+
+    doc1 = Document(title="Python", text="Python is a programming language.")
+    doc2 = Document(title="Dog", text="A dog is a loyal pet.")
+    doc3 = Document(title="Java", text="Java is another programming language.")
+    db_session.add_all([doc1, doc2, doc3])
+    db_session.commit()
+
+    faiss_store.add_document(doc1.text, doc1.id, session_id=dummy_session1.id, db_session=db_session)
+    faiss_store.add_document(doc2.text, doc2.id, session_id=dummy_session2.id, db_session=db_session)
+    faiss_store.add_document(doc3.text, doc3.id, session_id=dummy_session1.id, db_session=db_session)
 
     results = faiss_store.search_similar_documents(
-        "Which is a programming language?", k=2, prefilter_tags={"type": "programming"}
+        "Which is a programming language?", k=2,
+        prefilter_tags={"session_id": dummy_session1.id},
+        db_session=db_session
     )
 
     assert len(results) == 2
     for doc_id in results:
-        assert faiss_store.doc_tags[doc_id]["type"] == "programming"
+        metadata = db_session.query(VectorMetadata).filter_by(document_id=doc_id).first()
+        assert metadata.session_id == dummy_session1.id
 
-
-def test_search_with_no_matching_filter(faiss_store):
-    faiss_store.add_document("A document about cats.", tags={"species": "feline"})
+def test_search_with_no_matching_filter(faiss_store, db_session):
+    dummy_session = Session(user_id="user", title="Some Session")
+    dummy_doc = Document(title="Cat", text="A document about cats.")
+    db_session.add_all([dummy_session, dummy_doc])
+    db_session.commit()
+    faiss_store.add_document(dummy_doc.text, dummy_doc.id, session_id=dummy_session.id, db_session=db_session)
 
     results = faiss_store.search_similar_documents(
-        "What is a dog?", k=5, prefilter_tags={"species": "canine"}
+        "What is a dog?", k=5, prefilter_tags={"session_id": 9999}, db_session=db_session
    )
 
     assert len(results) == 0
+
+def test_search_with_invalid_filter_key(faiss_store, db_session):
+    results = faiss_store.search_similar_documents(
+        "A query.", k=1, prefilter_tags={"invalid_key": "some_value"}, db_session=db_session
+    )
+    assert len(results) == 0
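Because the primary key now doubles as the FAISS vector ID, the model tests below pin `id` explicitly instead of relying on autoincrement, which keeps the expected search result deterministic. The idea in miniature (a hypothetical helper, not part of this change):

# Sketch: pinning VectorMetadata.id so the expected FAISS ID is known up front.
from app.db.models import VectorMetadata

def make_pinned_metadata(db_session, document_id: int, faiss_id: int) -> VectorMetadata:
    meta = VectorMetadata(id=faiss_id, document_id=document_id,
                          embedding_model="test")
    db_session.add(meta)
    db_session.commit()
    return meta  # meta.id == faiss_id, the ID FAISS would return on a hit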
Session(user_id="user", title="Some Session") + dummy_doc = Document(title="Cat", text="A document about cats.") + db_session.add_all([dummy_session, dummy_doc]) + db_session.commit() + faiss_store.add_document(dummy_doc.text, dummy_doc.id, session_id=dummy_session.id, db_session=db_session) results = faiss_store.search_similar_documents( - "What is a dog?", k=5, prefilter_tags={"species": "canine"} + "What is a dog?", k=5, prefilter_tags={"session_id": 9999}, db_session=db_session ) assert len(results) == 0 + +def test_search_with_invalid_filter_key(faiss_store, db_session): + results = faiss_store.search_similar_documents( + "A query.", k=1, prefilter_tags={"invalid_key": "some_value"}, db_session=db_session + ) + assert len(results) == 0 diff --git a/ai-hub/tests/db/test_models.py b/ai-hub/tests/db/test_models.py index 20173e9..71dbcec 100644 --- a/ai-hub/tests/db/test_models.py +++ b/ai-hub/tests/db/test_models.py @@ -132,8 +132,8 @@ # Create vector metadata linked to the document vector_meta = VectorMetadata( + id=123, document_id=new_document.id, - faiss_index=123, embedding_model="test-model" ) db_session.add(vector_meta) @@ -144,7 +144,7 @@ retrieved_doc = db_session.query(Document).filter(Document.id == new_document.id).first() assert retrieved_doc is not None assert retrieved_doc.vector_metadata is not None - assert retrieved_doc.vector_metadata.faiss_index == 123 + assert retrieved_doc.vector_metadata.id == 123 def test_cascade_delete_session_and_messages(db_session): @@ -184,7 +184,7 @@ db_session.commit() db_session.refresh(new_document) - vector_meta = VectorMetadata(document_id=new_document.id, faiss_index=999, embedding_model="test") + vector_meta = VectorMetadata(id=999, document_id=new_document.id, embedding_model="test") db_session.add(vector_meta) db_session.commit() db_session.refresh(vector_meta) diff --git a/ai-hub/tests/test_app.py b/ai-hub/tests/test_app.py index 9d7e916..6dd2b4a 100644 --- a/ai-hub/tests/test_app.py +++ b/ai-hub/tests/test_app.py @@ -333,14 +333,14 @@ assert response.json()["detail"] == "Document with ID 999 not found." mock_services.document_service.delete_document.assert_called_once_with(db=mock_db, document_id=999) -@patch('app.core.vector_store.faiss_store.FaissVectorStore.save_index_and_metadata') +@patch('app.core.vector_store.faiss_store.FaissVectorStore.save_index') @patch('app.app.ServiceContainer') @patch('app.app.create_db_and_tables') @patch('app.app.print_config') @patch('app.core.vector_store.embedder.factory.get_embedder_from_config') @patch('os.path.exists', return_value=True) @patch('faiss.read_index') -def test_shutdown_saves_index_and_metadata(mock_read_index, mock_os_exists, mock_get_embedder, mock_print_config, mock_create_db, mock_service_container, mock_save_index_and_metadata): +def test_shutdown_saves_index(mock_read_index, mock_os_exists, mock_get_embedder, mock_print_config, mock_create_db, mock_service_container, mock_save_index): """ Tests that the FAISS index and its associated metadata are saved on application shutdown. """ @@ -372,4 +372,4 @@ # Assert # Check that the new save_index_and_metadata method was called exactly once. - mock_save_index_and_metadata.assert_called_once() + mock_save_index.assert_called_once()