diff --git a/ai-hub/app/api/dependencies.py b/ai-hub/app/api/dependencies.py index 18afe8f..e1f2be7 100644 --- a/ai-hub/app/api/dependencies.py +++ b/ai-hub/app/api/dependencies.py @@ -3,7 +3,7 @@ from typing import List from sqlalchemy.orm import Session from app.db.session import SessionLocal -from app.core.retrievers import Retriever +from app.core.retrievers.base_retriever import Retriever from app.core.services.document import DocumentService from app.core.services.rag import RAGService from app.core.vector_store.faiss_store import FaissVectorStore diff --git a/ai-hub/app/api/dependencies.py b/ai-hub/app/api/dependencies.py index 18afe8f..e1f2be7 100644 --- a/ai-hub/app/api/dependencies.py +++ b/ai-hub/app/api/dependencies.py @@ -3,7 +3,7 @@ from typing import List from sqlalchemy.orm import Session from app.db.session import SessionLocal -from app.core.retrievers import Retriever +from app.core.retrievers.base_retriever import Retriever from app.core.services.document import DocumentService from app.core.services.rag import RAGService from app.core.vector_store.faiss_store import FaissVectorStore diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index 94413c8..a747ae6 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -6,7 +6,8 @@ from app.config import settings from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.factory import get_embedder_from_config -from app.core.retrievers import FaissDBRetriever, Retriever +from app.core.retrievers.faiss_db_retriever import FaissDBRetriever +from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables from app.api.routes import create_api_router from app.utils import print_config diff --git a/ai-hub/app/api/dependencies.py b/ai-hub/app/api/dependencies.py index 18afe8f..e1f2be7 100644 --- a/ai-hub/app/api/dependencies.py +++ b/ai-hub/app/api/dependencies.py @@ -3,7 +3,7 @@ from typing import List from sqlalchemy.orm import Session from app.db.session import SessionLocal -from app.core.retrievers import Retriever +from app.core.retrievers.base_retriever import Retriever from app.core.services.document import DocumentService from app.core.services.rag import RAGService from app.core.vector_store.faiss_store import FaissVectorStore diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index 94413c8..a747ae6 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -6,7 +6,8 @@ from app.config import settings from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.factory import get_embedder_from_config -from app.core.retrievers import FaissDBRetriever, Retriever +from app.core.retrievers.faiss_db_retriever import FaissDBRetriever +from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables from app.api.routes import create_api_router from app.utils import print_config diff --git a/ai-hub/app/core/pipelines/dspy_rag.py b/ai-hub/app/core/pipelines/dspy_rag.py index d5e44de..88d1915 100644 --- a/ai-hub/app/core/pipelines/dspy_rag.py +++ b/ai-hub/app/core/pipelines/dspy_rag.py @@ -5,7 +5,7 @@ from sqlalchemy.orm import Session from app.db import models -from app.core.retrievers import Retriever +from app.core.retrievers.base_retriever import Retriever from app.core.llm_providers import LLMProvider diff --git a/ai-hub/app/api/dependencies.py b/ai-hub/app/api/dependencies.py index 18afe8f..e1f2be7 100644 --- a/ai-hub/app/api/dependencies.py +++ b/ai-hub/app/api/dependencies.py @@ -3,7 +3,7 @@ from typing import List from sqlalchemy.orm import Session from app.db.session import SessionLocal -from app.core.retrievers import Retriever +from app.core.retrievers.base_retriever import Retriever from app.core.services.document import DocumentService from app.core.services.rag import RAGService from app.core.vector_store.faiss_store import FaissVectorStore diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index 94413c8..a747ae6 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -6,7 +6,8 @@ from app.config import settings from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.factory import get_embedder_from_config -from app.core.retrievers import FaissDBRetriever, Retriever +from app.core.retrievers.faiss_db_retriever import FaissDBRetriever +from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables from app.api.routes import create_api_router from app.utils import print_config diff --git a/ai-hub/app/core/pipelines/dspy_rag.py b/ai-hub/app/core/pipelines/dspy_rag.py index d5e44de..88d1915 100644 --- a/ai-hub/app/core/pipelines/dspy_rag.py +++ b/ai-hub/app/core/pipelines/dspy_rag.py @@ -5,7 +5,7 @@ from sqlalchemy.orm import Session from app.db import models -from app.core.retrievers import Retriever +from app.core.retrievers.base_retriever import Retriever from app.core.llm_providers import LLMProvider diff --git a/ai-hub/app/core/retrievers.py b/ai-hub/app/core/retrievers.py deleted file mode 100644 index d2c158b..0000000 --- a/ai-hub/app/core/retrievers.py +++ /dev/null @@ -1,65 +0,0 @@ -import abc -from typing import List, Dict -from sqlalchemy.orm import Session -from app.core.vector_store.faiss_store import FaissVectorStore -from app.db import models - -class Retriever(abc.ABC): - """ - Abstract base class for a Retriever. - - A retriever is a pluggable component that is responsible for fetching - relevant context for a given query from a specific data source. - """ - @abc.abstractmethod - def retrieve_context(self, query: str, db: Session) -> List[str]: - """ - Fetches context for a given query. - - Args: - query (str): The user's query string. - db (Session): The database session. - - Returns: - List[str]: A list of text strings representing the retrieved context. - """ - raise NotImplementedError - -class FaissDBRetriever(Retriever): - """ - A concrete retriever that uses a FAISS index and a local database - to find and return relevant document text. - """ - def __init__(self, vector_store: FaissVectorStore): - self.vector_store = vector_store - - def retrieve_context(self, query: str, db: Session) -> List[str]: - """ - Retrieves document text by first searching the FAISS index - and then fetching the corresponding documents from the database. - """ - faiss_ids = self.vector_store.search_similar_documents(query, k=3) - context_docs_text = [] - - if faiss_ids: - # Use FAISS IDs to find the corresponding document_id from the database - document_ids_from_vectors = db.query(models.VectorMetadata.document_id).filter( - models.VectorMetadata.faiss_index.in_(faiss_ids) - ).all() - - document_ids = [doc_id for (doc_id,) in document_ids_from_vectors] - - # Retrieve the full documents from the Document table - context_docs = db.query(models.Document).filter( - models.Document.id.in_(document_ids) - ).all() - - context_docs_text = [doc.text for doc in context_docs] - - return context_docs_text - -# You could add other retriever implementations here, like: -# class RemoteServiceRetriever(Retriever): -# def retrieve_context(self, query: str, db: Session) -> List[str]: -# # Logic to call a remote API and return context -# ... diff --git a/ai-hub/app/api/dependencies.py b/ai-hub/app/api/dependencies.py index 18afe8f..e1f2be7 100644 --- a/ai-hub/app/api/dependencies.py +++ b/ai-hub/app/api/dependencies.py @@ -3,7 +3,7 @@ from typing import List from sqlalchemy.orm import Session from app.db.session import SessionLocal -from app.core.retrievers import Retriever +from app.core.retrievers.base_retriever import Retriever from app.core.services.document import DocumentService from app.core.services.rag import RAGService from app.core.vector_store.faiss_store import FaissVectorStore diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index 94413c8..a747ae6 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -6,7 +6,8 @@ from app.config import settings from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.factory import get_embedder_from_config -from app.core.retrievers import FaissDBRetriever, Retriever +from app.core.retrievers.faiss_db_retriever import FaissDBRetriever +from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables from app.api.routes import create_api_router from app.utils import print_config diff --git a/ai-hub/app/core/pipelines/dspy_rag.py b/ai-hub/app/core/pipelines/dspy_rag.py index d5e44de..88d1915 100644 --- a/ai-hub/app/core/pipelines/dspy_rag.py +++ b/ai-hub/app/core/pipelines/dspy_rag.py @@ -5,7 +5,7 @@ from sqlalchemy.orm import Session from app.db import models -from app.core.retrievers import Retriever +from app.core.retrievers.base_retriever import Retriever from app.core.llm_providers import LLMProvider diff --git a/ai-hub/app/core/retrievers.py b/ai-hub/app/core/retrievers.py deleted file mode 100644 index d2c158b..0000000 --- a/ai-hub/app/core/retrievers.py +++ /dev/null @@ -1,65 +0,0 @@ -import abc -from typing import List, Dict -from sqlalchemy.orm import Session -from app.core.vector_store.faiss_store import FaissVectorStore -from app.db import models - -class Retriever(abc.ABC): - """ - Abstract base class for a Retriever. - - A retriever is a pluggable component that is responsible for fetching - relevant context for a given query from a specific data source. - """ - @abc.abstractmethod - def retrieve_context(self, query: str, db: Session) -> List[str]: - """ - Fetches context for a given query. - - Args: - query (str): The user's query string. - db (Session): The database session. - - Returns: - List[str]: A list of text strings representing the retrieved context. - """ - raise NotImplementedError - -class FaissDBRetriever(Retriever): - """ - A concrete retriever that uses a FAISS index and a local database - to find and return relevant document text. - """ - def __init__(self, vector_store: FaissVectorStore): - self.vector_store = vector_store - - def retrieve_context(self, query: str, db: Session) -> List[str]: - """ - Retrieves document text by first searching the FAISS index - and then fetching the corresponding documents from the database. - """ - faiss_ids = self.vector_store.search_similar_documents(query, k=3) - context_docs_text = [] - - if faiss_ids: - # Use FAISS IDs to find the corresponding document_id from the database - document_ids_from_vectors = db.query(models.VectorMetadata.document_id).filter( - models.VectorMetadata.faiss_index.in_(faiss_ids) - ).all() - - document_ids = [doc_id for (doc_id,) in document_ids_from_vectors] - - # Retrieve the full documents from the Document table - context_docs = db.query(models.Document).filter( - models.Document.id.in_(document_ids) - ).all() - - context_docs_text = [doc.text for doc in context_docs] - - return context_docs_text - -# You could add other retriever implementations here, like: -# class RemoteServiceRetriever(Retriever): -# def retrieve_context(self, query: str, db: Session) -> List[str]: -# # Logic to call a remote API and return context -# ... diff --git a/ai-hub/app/core/retrievers/__init__.py b/ai-hub/app/core/retrievers/__init__.py new file mode 100644 index 0000000..3fbb1fd --- /dev/null +++ b/ai-hub/app/core/retrievers/__init__.py @@ -0,0 +1 @@ +# This file can be left empty. diff --git a/ai-hub/app/api/dependencies.py b/ai-hub/app/api/dependencies.py index 18afe8f..e1f2be7 100644 --- a/ai-hub/app/api/dependencies.py +++ b/ai-hub/app/api/dependencies.py @@ -3,7 +3,7 @@ from typing import List from sqlalchemy.orm import Session from app.db.session import SessionLocal -from app.core.retrievers import Retriever +from app.core.retrievers.base_retriever import Retriever from app.core.services.document import DocumentService from app.core.services.rag import RAGService from app.core.vector_store.faiss_store import FaissVectorStore diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index 94413c8..a747ae6 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -6,7 +6,8 @@ from app.config import settings from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.factory import get_embedder_from_config -from app.core.retrievers import FaissDBRetriever, Retriever +from app.core.retrievers.faiss_db_retriever import FaissDBRetriever +from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables from app.api.routes import create_api_router from app.utils import print_config diff --git a/ai-hub/app/core/pipelines/dspy_rag.py b/ai-hub/app/core/pipelines/dspy_rag.py index d5e44de..88d1915 100644 --- a/ai-hub/app/core/pipelines/dspy_rag.py +++ b/ai-hub/app/core/pipelines/dspy_rag.py @@ -5,7 +5,7 @@ from sqlalchemy.orm import Session from app.db import models -from app.core.retrievers import Retriever +from app.core.retrievers.base_retriever import Retriever from app.core.llm_providers import LLMProvider diff --git a/ai-hub/app/core/retrievers.py b/ai-hub/app/core/retrievers.py deleted file mode 100644 index d2c158b..0000000 --- a/ai-hub/app/core/retrievers.py +++ /dev/null @@ -1,65 +0,0 @@ -import abc -from typing import List, Dict -from sqlalchemy.orm import Session -from app.core.vector_store.faiss_store import FaissVectorStore -from app.db import models - -class Retriever(abc.ABC): - """ - Abstract base class for a Retriever. - - A retriever is a pluggable component that is responsible for fetching - relevant context for a given query from a specific data source. - """ - @abc.abstractmethod - def retrieve_context(self, query: str, db: Session) -> List[str]: - """ - Fetches context for a given query. - - Args: - query (str): The user's query string. - db (Session): The database session. - - Returns: - List[str]: A list of text strings representing the retrieved context. - """ - raise NotImplementedError - -class FaissDBRetriever(Retriever): - """ - A concrete retriever that uses a FAISS index and a local database - to find and return relevant document text. - """ - def __init__(self, vector_store: FaissVectorStore): - self.vector_store = vector_store - - def retrieve_context(self, query: str, db: Session) -> List[str]: - """ - Retrieves document text by first searching the FAISS index - and then fetching the corresponding documents from the database. - """ - faiss_ids = self.vector_store.search_similar_documents(query, k=3) - context_docs_text = [] - - if faiss_ids: - # Use FAISS IDs to find the corresponding document_id from the database - document_ids_from_vectors = db.query(models.VectorMetadata.document_id).filter( - models.VectorMetadata.faiss_index.in_(faiss_ids) - ).all() - - document_ids = [doc_id for (doc_id,) in document_ids_from_vectors] - - # Retrieve the full documents from the Document table - context_docs = db.query(models.Document).filter( - models.Document.id.in_(document_ids) - ).all() - - context_docs_text = [doc.text for doc in context_docs] - - return context_docs_text - -# You could add other retriever implementations here, like: -# class RemoteServiceRetriever(Retriever): -# def retrieve_context(self, query: str, db: Session) -> List[str]: -# # Logic to call a remote API and return context -# ... diff --git a/ai-hub/app/core/retrievers/__init__.py b/ai-hub/app/core/retrievers/__init__.py new file mode 100644 index 0000000..3fbb1fd --- /dev/null +++ b/ai-hub/app/core/retrievers/__init__.py @@ -0,0 +1 @@ +# This file can be left empty. diff --git a/ai-hub/app/core/retrievers/base_retriever.py b/ai-hub/app/core/retrievers/base_retriever.py new file mode 100644 index 0000000..b902132 --- /dev/null +++ b/ai-hub/app/core/retrievers/base_retriever.py @@ -0,0 +1,24 @@ +import abc +from typing import List +from sqlalchemy.orm import Session + +class Retriever(abc.ABC): + """ + Abstract base class for a Retriever. + + A retriever is a pluggable component that is responsible for fetching + relevant context for a given query from a specific data source. + """ + @abc.abstractmethod + def retrieve_context(self, query: str, db: Session) -> List[str]: + """ + Fetches context for a given query. + + Args: + query (str): The user's query string. + db (Session): The database session. + + Returns: + List[str]: A list of text strings representing the retrieved context. + """ + raise NotImplementedError \ No newline at end of file diff --git a/ai-hub/app/api/dependencies.py b/ai-hub/app/api/dependencies.py index 18afe8f..e1f2be7 100644 --- a/ai-hub/app/api/dependencies.py +++ b/ai-hub/app/api/dependencies.py @@ -3,7 +3,7 @@ from typing import List from sqlalchemy.orm import Session from app.db.session import SessionLocal -from app.core.retrievers import Retriever +from app.core.retrievers.base_retriever import Retriever from app.core.services.document import DocumentService from app.core.services.rag import RAGService from app.core.vector_store.faiss_store import FaissVectorStore diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index 94413c8..a747ae6 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -6,7 +6,8 @@ from app.config import settings from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.factory import get_embedder_from_config -from app.core.retrievers import FaissDBRetriever, Retriever +from app.core.retrievers.faiss_db_retriever import FaissDBRetriever +from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables from app.api.routes import create_api_router from app.utils import print_config diff --git a/ai-hub/app/core/pipelines/dspy_rag.py b/ai-hub/app/core/pipelines/dspy_rag.py index d5e44de..88d1915 100644 --- a/ai-hub/app/core/pipelines/dspy_rag.py +++ b/ai-hub/app/core/pipelines/dspy_rag.py @@ -5,7 +5,7 @@ from sqlalchemy.orm import Session from app.db import models -from app.core.retrievers import Retriever +from app.core.retrievers.base_retriever import Retriever from app.core.llm_providers import LLMProvider diff --git a/ai-hub/app/core/retrievers.py b/ai-hub/app/core/retrievers.py deleted file mode 100644 index d2c158b..0000000 --- a/ai-hub/app/core/retrievers.py +++ /dev/null @@ -1,65 +0,0 @@ -import abc -from typing import List, Dict -from sqlalchemy.orm import Session -from app.core.vector_store.faiss_store import FaissVectorStore -from app.db import models - -class Retriever(abc.ABC): - """ - Abstract base class for a Retriever. - - A retriever is a pluggable component that is responsible for fetching - relevant context for a given query from a specific data source. - """ - @abc.abstractmethod - def retrieve_context(self, query: str, db: Session) -> List[str]: - """ - Fetches context for a given query. - - Args: - query (str): The user's query string. - db (Session): The database session. - - Returns: - List[str]: A list of text strings representing the retrieved context. - """ - raise NotImplementedError - -class FaissDBRetriever(Retriever): - """ - A concrete retriever that uses a FAISS index and a local database - to find and return relevant document text. - """ - def __init__(self, vector_store: FaissVectorStore): - self.vector_store = vector_store - - def retrieve_context(self, query: str, db: Session) -> List[str]: - """ - Retrieves document text by first searching the FAISS index - and then fetching the corresponding documents from the database. - """ - faiss_ids = self.vector_store.search_similar_documents(query, k=3) - context_docs_text = [] - - if faiss_ids: - # Use FAISS IDs to find the corresponding document_id from the database - document_ids_from_vectors = db.query(models.VectorMetadata.document_id).filter( - models.VectorMetadata.faiss_index.in_(faiss_ids) - ).all() - - document_ids = [doc_id for (doc_id,) in document_ids_from_vectors] - - # Retrieve the full documents from the Document table - context_docs = db.query(models.Document).filter( - models.Document.id.in_(document_ids) - ).all() - - context_docs_text = [doc.text for doc in context_docs] - - return context_docs_text - -# You could add other retriever implementations here, like: -# class RemoteServiceRetriever(Retriever): -# def retrieve_context(self, query: str, db: Session) -> List[str]: -# # Logic to call a remote API and return context -# ... diff --git a/ai-hub/app/core/retrievers/__init__.py b/ai-hub/app/core/retrievers/__init__.py new file mode 100644 index 0000000..3fbb1fd --- /dev/null +++ b/ai-hub/app/core/retrievers/__init__.py @@ -0,0 +1 @@ +# This file can be left empty. diff --git a/ai-hub/app/core/retrievers/base_retriever.py b/ai-hub/app/core/retrievers/base_retriever.py new file mode 100644 index 0000000..b902132 --- /dev/null +++ b/ai-hub/app/core/retrievers/base_retriever.py @@ -0,0 +1,24 @@ +import abc +from typing import List +from sqlalchemy.orm import Session + +class Retriever(abc.ABC): + """ + Abstract base class for a Retriever. + + A retriever is a pluggable component that is responsible for fetching + relevant context for a given query from a specific data source. + """ + @abc.abstractmethod + def retrieve_context(self, query: str, db: Session) -> List[str]: + """ + Fetches context for a given query. + + Args: + query (str): The user's query string. + db (Session): The database session. + + Returns: + List[str]: A list of text strings representing the retrieved context. + """ + raise NotImplementedError \ No newline at end of file diff --git a/ai-hub/app/core/retrievers/faiss_db_retriever.py b/ai-hub/app/core/retrievers/faiss_db_retriever.py new file mode 100644 index 0000000..3c3c202 --- /dev/null +++ b/ai-hub/app/core/retrievers/faiss_db_retriever.py @@ -0,0 +1,38 @@ +from typing import List +from sqlalchemy.orm import Session +from app.core.vector_store.faiss_store import FaissVectorStore +from app.db import models +from app.core.retrievers.base_retriever import Retriever + +class FaissDBRetriever(Retriever): + """ + A concrete retriever that uses a FAISS index and a local database + to find and return relevant document text. + """ + def __init__(self, vector_store: FaissVectorStore): + self.vector_store = vector_store + + def retrieve_context(self, query: str, db: Session) -> List[str]: + """ + Retrieves document text by first searching the FAISS index + and then fetching the corresponding documents from the database. + """ + faiss_ids = self.vector_store.search_similar_documents(query, k=3) + context_docs_text = [] + + if faiss_ids: + # Use FAISS IDs to find the corresponding document_id from the database + document_ids_from_vectors = db.query(models.VectorMetadata.document_id).filter( + models.VectorMetadata.faiss_index.in_(faiss_ids) + ).all() + + document_ids = [doc_id for (doc_id,) in document_ids_from_vectors] + + # Retrieve the full documents from the Document table + context_docs = db.query(models.Document).filter( + models.Document.id.in_(document_ids) + ).all() + + context_docs_text = [doc.text for doc in context_docs] + + return context_docs_text \ No newline at end of file diff --git a/ai-hub/app/api/dependencies.py b/ai-hub/app/api/dependencies.py index 18afe8f..e1f2be7 100644 --- a/ai-hub/app/api/dependencies.py +++ b/ai-hub/app/api/dependencies.py @@ -3,7 +3,7 @@ from typing import List from sqlalchemy.orm import Session from app.db.session import SessionLocal -from app.core.retrievers import Retriever +from app.core.retrievers.base_retriever import Retriever from app.core.services.document import DocumentService from app.core.services.rag import RAGService from app.core.vector_store.faiss_store import FaissVectorStore diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index 94413c8..a747ae6 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -6,7 +6,8 @@ from app.config import settings from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.factory import get_embedder_from_config -from app.core.retrievers import FaissDBRetriever, Retriever +from app.core.retrievers.faiss_db_retriever import FaissDBRetriever +from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables from app.api.routes import create_api_router from app.utils import print_config diff --git a/ai-hub/app/core/pipelines/dspy_rag.py b/ai-hub/app/core/pipelines/dspy_rag.py index d5e44de..88d1915 100644 --- a/ai-hub/app/core/pipelines/dspy_rag.py +++ b/ai-hub/app/core/pipelines/dspy_rag.py @@ -5,7 +5,7 @@ from sqlalchemy.orm import Session from app.db import models -from app.core.retrievers import Retriever +from app.core.retrievers.base_retriever import Retriever from app.core.llm_providers import LLMProvider diff --git a/ai-hub/app/core/retrievers.py b/ai-hub/app/core/retrievers.py deleted file mode 100644 index d2c158b..0000000 --- a/ai-hub/app/core/retrievers.py +++ /dev/null @@ -1,65 +0,0 @@ -import abc -from typing import List, Dict -from sqlalchemy.orm import Session -from app.core.vector_store.faiss_store import FaissVectorStore -from app.db import models - -class Retriever(abc.ABC): - """ - Abstract base class for a Retriever. - - A retriever is a pluggable component that is responsible for fetching - relevant context for a given query from a specific data source. - """ - @abc.abstractmethod - def retrieve_context(self, query: str, db: Session) -> List[str]: - """ - Fetches context for a given query. - - Args: - query (str): The user's query string. - db (Session): The database session. - - Returns: - List[str]: A list of text strings representing the retrieved context. - """ - raise NotImplementedError - -class FaissDBRetriever(Retriever): - """ - A concrete retriever that uses a FAISS index and a local database - to find and return relevant document text. - """ - def __init__(self, vector_store: FaissVectorStore): - self.vector_store = vector_store - - def retrieve_context(self, query: str, db: Session) -> List[str]: - """ - Retrieves document text by first searching the FAISS index - and then fetching the corresponding documents from the database. - """ - faiss_ids = self.vector_store.search_similar_documents(query, k=3) - context_docs_text = [] - - if faiss_ids: - # Use FAISS IDs to find the corresponding document_id from the database - document_ids_from_vectors = db.query(models.VectorMetadata.document_id).filter( - models.VectorMetadata.faiss_index.in_(faiss_ids) - ).all() - - document_ids = [doc_id for (doc_id,) in document_ids_from_vectors] - - # Retrieve the full documents from the Document table - context_docs = db.query(models.Document).filter( - models.Document.id.in_(document_ids) - ).all() - - context_docs_text = [doc.text for doc in context_docs] - - return context_docs_text - -# You could add other retriever implementations here, like: -# class RemoteServiceRetriever(Retriever): -# def retrieve_context(self, query: str, db: Session) -> List[str]: -# # Logic to call a remote API and return context -# ... diff --git a/ai-hub/app/core/retrievers/__init__.py b/ai-hub/app/core/retrievers/__init__.py new file mode 100644 index 0000000..3fbb1fd --- /dev/null +++ b/ai-hub/app/core/retrievers/__init__.py @@ -0,0 +1 @@ +# This file can be left empty. diff --git a/ai-hub/app/core/retrievers/base_retriever.py b/ai-hub/app/core/retrievers/base_retriever.py new file mode 100644 index 0000000..b902132 --- /dev/null +++ b/ai-hub/app/core/retrievers/base_retriever.py @@ -0,0 +1,24 @@ +import abc +from typing import List +from sqlalchemy.orm import Session + +class Retriever(abc.ABC): + """ + Abstract base class for a Retriever. + + A retriever is a pluggable component that is responsible for fetching + relevant context for a given query from a specific data source. + """ + @abc.abstractmethod + def retrieve_context(self, query: str, db: Session) -> List[str]: + """ + Fetches context for a given query. + + Args: + query (str): The user's query string. + db (Session): The database session. + + Returns: + List[str]: A list of text strings representing the retrieved context. + """ + raise NotImplementedError \ No newline at end of file diff --git a/ai-hub/app/core/retrievers/faiss_db_retriever.py b/ai-hub/app/core/retrievers/faiss_db_retriever.py new file mode 100644 index 0000000..3c3c202 --- /dev/null +++ b/ai-hub/app/core/retrievers/faiss_db_retriever.py @@ -0,0 +1,38 @@ +from typing import List +from sqlalchemy.orm import Session +from app.core.vector_store.faiss_store import FaissVectorStore +from app.db import models +from app.core.retrievers.base_retriever import Retriever + +class FaissDBRetriever(Retriever): + """ + A concrete retriever that uses a FAISS index and a local database + to find and return relevant document text. + """ + def __init__(self, vector_store: FaissVectorStore): + self.vector_store = vector_store + + def retrieve_context(self, query: str, db: Session) -> List[str]: + """ + Retrieves document text by first searching the FAISS index + and then fetching the corresponding documents from the database. + """ + faiss_ids = self.vector_store.search_similar_documents(query, k=3) + context_docs_text = [] + + if faiss_ids: + # Use FAISS IDs to find the corresponding document_id from the database + document_ids_from_vectors = db.query(models.VectorMetadata.document_id).filter( + models.VectorMetadata.faiss_index.in_(faiss_ids) + ).all() + + document_ids = [doc_id for (doc_id,) in document_ids_from_vectors] + + # Retrieve the full documents from the Document table + context_docs = db.query(models.Document).filter( + models.Document.id.in_(document_ids) + ).all() + + context_docs_text = [doc.text for doc in context_docs] + + return context_docs_text \ No newline at end of file diff --git a/ai-hub/app/core/services.py b/ai-hub/app/core/services.py deleted file mode 100644 index a5694dc..0000000 --- a/ai-hub/app/core/services.py +++ /dev/null @@ -1,141 +0,0 @@ -# import asyncio -# from typing import List, Dict, Any, Tuple -# from sqlalchemy.orm import Session, joinedload -# from sqlalchemy.exc import SQLAlchemyError -# import dspy - -# from app.core.vector_store.faiss_store import FaissVectorStore -# from app.core.vector_store.embedder.mock import MockEmbedder -# from app.db import models -# from app.core.retrievers import Retriever, FaissDBRetriever -# from app.core.llm_providers import get_llm_provider -# from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline - -# class RAGService: -# """ -# Service class for managing documents and conversational RAG sessions. -# This class is now more robust and can handle both real and mock embedders -# by inspecting its dependencies. -# """ -# def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): -# self.vector_store = vector_store -# self.retrievers = retrievers - -# # Assume one of the retrievers is the FAISS retriever, and you can access it. -# # A better approach might be to have a dictionary of named retrievers. -# self.faiss_retriever = next((r for r in retrievers if isinstance(r, FaissDBRetriever)), None) - -# # Store the embedder from the vector store for dynamic naming -# self.embedder = self.vector_store.embedder - - -# # --- Session Management --- - -# def create_session(self, db: Session, user_id: str, model: str) -> models.Session: -# """Creates a new chat session in the database.""" -# try: -# new_session = models.Session(user_id=user_id, model_name=model, title=f"New Chat Session") -# db.add(new_session) -# db.commit() -# db.refresh(new_session) -# return new_session -# except SQLAlchemyError as e: -# db.rollback() -# raise - -# async def chat_with_rag( -# self, -# db: Session, -# session_id: int, -# prompt: str, -# model: str, -# load_faiss_retriever: bool = False -# ) -> Tuple[str, str]: -# """ -# Handles a message within a session, including saving history and getting a response. -# Allows switching the LLM model and conditionally using the FAISS retriever. -# """ -# session = db.query(models.Session).options( -# joinedload(models.Session.messages) -# ).filter(models.Session.id == session_id).first() - -# if not session: -# raise ValueError(f"Session with ID {session_id} not found.") - -# user_message = models.Message(session_id=session_id, sender="user", content=prompt) -# db.add(user_message) -# db.commit() - -# llm_provider = get_llm_provider(model) -# dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=model) -# dspy.configure(lm=dspy_llm) - -# current_retrievers = [] -# if load_faiss_retriever: -# if self.faiss_retriever: -# current_retrievers.append(self.faiss_retriever) -# else: -# print("Warning: FaissDBRetriever requested but not available. Proceeding without it.") - -# rag_pipeline = DspyRagPipeline(retrievers=current_retrievers) - -# answer_text = await rag_pipeline.forward( -# question=prompt, -# history=session.messages, -# db=db -# ) - -# assistant_message = models.Message(session_id=session_id, sender="assistant", content=answer_text) -# db.add(assistant_message) -# db.commit() - -# return answer_text, model - -# def get_message_history(self, db: Session, session_id: int) -> List[models.Message]: -# """ -# Retrieves all messages for a given session, or None if the session doesn't exist. -# """ -# session = db.query(models.Session).options( -# joinedload(models.Session.messages) -# ).filter(models.Session.id == session_id).first() - -# return sorted(session.messages, key=lambda msg: msg.created_at) if session else None - -# # --- Document Management (Updated) --- -# def add_document(self, db: Session, doc_data: Dict[str, Any]) -> int: -# try: -# document_db = models.Document(**doc_data) -# db.add(document_db) -# db.commit() -# db.refresh(document_db) - -# # Use the embedder provided to the vector store to get the correct model name -# embedding_model_name = "mock_embedder" if isinstance(self.embedder, MockEmbedder) else "GenAIEmbedder" - -# faiss_index = self.vector_store.add_document(document_db.text) -# vector_metadata = models.VectorMetadata( -# document_id=document_db.id, -# faiss_index=faiss_index, -# embedding_model=embedding_model_name -# ) -# db.add(vector_metadata) -# db.commit() -# return document_db.id -# except SQLAlchemyError as e: -# db.rollback() -# raise - -# def get_all_documents(self, db: Session) -> List[models.Document]: -# return db.query(models.Document).order_by(models.Document.created_at.desc()).all() - -# def delete_document(self, db: Session, document_id: int) -> int: -# try: -# doc_to_delete = db.query(models.Document).filter(models.Document.id == document_id).first() -# if not doc_to_delete: -# return None -# db.delete(doc_to_delete) -# db.commit() -# return document_id -# except SQLAlchemyError as e: -# db.rollback() -# raise diff --git a/ai-hub/app/api/dependencies.py b/ai-hub/app/api/dependencies.py index 18afe8f..e1f2be7 100644 --- a/ai-hub/app/api/dependencies.py +++ b/ai-hub/app/api/dependencies.py @@ -3,7 +3,7 @@ from typing import List from sqlalchemy.orm import Session from app.db.session import SessionLocal -from app.core.retrievers import Retriever +from app.core.retrievers.base_retriever import Retriever from app.core.services.document import DocumentService from app.core.services.rag import RAGService from app.core.vector_store.faiss_store import FaissVectorStore diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index 94413c8..a747ae6 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -6,7 +6,8 @@ from app.config import settings from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.factory import get_embedder_from_config -from app.core.retrievers import FaissDBRetriever, Retriever +from app.core.retrievers.faiss_db_retriever import FaissDBRetriever +from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables from app.api.routes import create_api_router from app.utils import print_config diff --git a/ai-hub/app/core/pipelines/dspy_rag.py b/ai-hub/app/core/pipelines/dspy_rag.py index d5e44de..88d1915 100644 --- a/ai-hub/app/core/pipelines/dspy_rag.py +++ b/ai-hub/app/core/pipelines/dspy_rag.py @@ -5,7 +5,7 @@ from sqlalchemy.orm import Session from app.db import models -from app.core.retrievers import Retriever +from app.core.retrievers.base_retriever import Retriever from app.core.llm_providers import LLMProvider diff --git a/ai-hub/app/core/retrievers.py b/ai-hub/app/core/retrievers.py deleted file mode 100644 index d2c158b..0000000 --- a/ai-hub/app/core/retrievers.py +++ /dev/null @@ -1,65 +0,0 @@ -import abc -from typing import List, Dict -from sqlalchemy.orm import Session -from app.core.vector_store.faiss_store import FaissVectorStore -from app.db import models - -class Retriever(abc.ABC): - """ - Abstract base class for a Retriever. - - A retriever is a pluggable component that is responsible for fetching - relevant context for a given query from a specific data source. - """ - @abc.abstractmethod - def retrieve_context(self, query: str, db: Session) -> List[str]: - """ - Fetches context for a given query. - - Args: - query (str): The user's query string. - db (Session): The database session. - - Returns: - List[str]: A list of text strings representing the retrieved context. - """ - raise NotImplementedError - -class FaissDBRetriever(Retriever): - """ - A concrete retriever that uses a FAISS index and a local database - to find and return relevant document text. - """ - def __init__(self, vector_store: FaissVectorStore): - self.vector_store = vector_store - - def retrieve_context(self, query: str, db: Session) -> List[str]: - """ - Retrieves document text by first searching the FAISS index - and then fetching the corresponding documents from the database. - """ - faiss_ids = self.vector_store.search_similar_documents(query, k=3) - context_docs_text = [] - - if faiss_ids: - # Use FAISS IDs to find the corresponding document_id from the database - document_ids_from_vectors = db.query(models.VectorMetadata.document_id).filter( - models.VectorMetadata.faiss_index.in_(faiss_ids) - ).all() - - document_ids = [doc_id for (doc_id,) in document_ids_from_vectors] - - # Retrieve the full documents from the Document table - context_docs = db.query(models.Document).filter( - models.Document.id.in_(document_ids) - ).all() - - context_docs_text = [doc.text for doc in context_docs] - - return context_docs_text - -# You could add other retriever implementations here, like: -# class RemoteServiceRetriever(Retriever): -# def retrieve_context(self, query: str, db: Session) -> List[str]: -# # Logic to call a remote API and return context -# ... diff --git a/ai-hub/app/core/retrievers/__init__.py b/ai-hub/app/core/retrievers/__init__.py new file mode 100644 index 0000000..3fbb1fd --- /dev/null +++ b/ai-hub/app/core/retrievers/__init__.py @@ -0,0 +1 @@ +# This file can be left empty. diff --git a/ai-hub/app/core/retrievers/base_retriever.py b/ai-hub/app/core/retrievers/base_retriever.py new file mode 100644 index 0000000..b902132 --- /dev/null +++ b/ai-hub/app/core/retrievers/base_retriever.py @@ -0,0 +1,24 @@ +import abc +from typing import List +from sqlalchemy.orm import Session + +class Retriever(abc.ABC): + """ + Abstract base class for a Retriever. + + A retriever is a pluggable component that is responsible for fetching + relevant context for a given query from a specific data source. + """ + @abc.abstractmethod + def retrieve_context(self, query: str, db: Session) -> List[str]: + """ + Fetches context for a given query. + + Args: + query (str): The user's query string. + db (Session): The database session. + + Returns: + List[str]: A list of text strings representing the retrieved context. + """ + raise NotImplementedError \ No newline at end of file diff --git a/ai-hub/app/core/retrievers/faiss_db_retriever.py b/ai-hub/app/core/retrievers/faiss_db_retriever.py new file mode 100644 index 0000000..3c3c202 --- /dev/null +++ b/ai-hub/app/core/retrievers/faiss_db_retriever.py @@ -0,0 +1,38 @@ +from typing import List +from sqlalchemy.orm import Session +from app.core.vector_store.faiss_store import FaissVectorStore +from app.db import models +from app.core.retrievers.base_retriever import Retriever + +class FaissDBRetriever(Retriever): + """ + A concrete retriever that uses a FAISS index and a local database + to find and return relevant document text. + """ + def __init__(self, vector_store: FaissVectorStore): + self.vector_store = vector_store + + def retrieve_context(self, query: str, db: Session) -> List[str]: + """ + Retrieves document text by first searching the FAISS index + and then fetching the corresponding documents from the database. + """ + faiss_ids = self.vector_store.search_similar_documents(query, k=3) + context_docs_text = [] + + if faiss_ids: + # Use FAISS IDs to find the corresponding document_id from the database + document_ids_from_vectors = db.query(models.VectorMetadata.document_id).filter( + models.VectorMetadata.faiss_index.in_(faiss_ids) + ).all() + + document_ids = [doc_id for (doc_id,) in document_ids_from_vectors] + + # Retrieve the full documents from the Document table + context_docs = db.query(models.Document).filter( + models.Document.id.in_(document_ids) + ).all() + + context_docs_text = [doc.text for doc in context_docs] + + return context_docs_text \ No newline at end of file diff --git a/ai-hub/app/core/services.py b/ai-hub/app/core/services.py deleted file mode 100644 index a5694dc..0000000 --- a/ai-hub/app/core/services.py +++ /dev/null @@ -1,141 +0,0 @@ -# import asyncio -# from typing import List, Dict, Any, Tuple -# from sqlalchemy.orm import Session, joinedload -# from sqlalchemy.exc import SQLAlchemyError -# import dspy - -# from app.core.vector_store.faiss_store import FaissVectorStore -# from app.core.vector_store.embedder.mock import MockEmbedder -# from app.db import models -# from app.core.retrievers import Retriever, FaissDBRetriever -# from app.core.llm_providers import get_llm_provider -# from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline - -# class RAGService: -# """ -# Service class for managing documents and conversational RAG sessions. -# This class is now more robust and can handle both real and mock embedders -# by inspecting its dependencies. -# """ -# def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): -# self.vector_store = vector_store -# self.retrievers = retrievers - -# # Assume one of the retrievers is the FAISS retriever, and you can access it. -# # A better approach might be to have a dictionary of named retrievers. -# self.faiss_retriever = next((r for r in retrievers if isinstance(r, FaissDBRetriever)), None) - -# # Store the embedder from the vector store for dynamic naming -# self.embedder = self.vector_store.embedder - - -# # --- Session Management --- - -# def create_session(self, db: Session, user_id: str, model: str) -> models.Session: -# """Creates a new chat session in the database.""" -# try: -# new_session = models.Session(user_id=user_id, model_name=model, title=f"New Chat Session") -# db.add(new_session) -# db.commit() -# db.refresh(new_session) -# return new_session -# except SQLAlchemyError as e: -# db.rollback() -# raise - -# async def chat_with_rag( -# self, -# db: Session, -# session_id: int, -# prompt: str, -# model: str, -# load_faiss_retriever: bool = False -# ) -> Tuple[str, str]: -# """ -# Handles a message within a session, including saving history and getting a response. -# Allows switching the LLM model and conditionally using the FAISS retriever. -# """ -# session = db.query(models.Session).options( -# joinedload(models.Session.messages) -# ).filter(models.Session.id == session_id).first() - -# if not session: -# raise ValueError(f"Session with ID {session_id} not found.") - -# user_message = models.Message(session_id=session_id, sender="user", content=prompt) -# db.add(user_message) -# db.commit() - -# llm_provider = get_llm_provider(model) -# dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=model) -# dspy.configure(lm=dspy_llm) - -# current_retrievers = [] -# if load_faiss_retriever: -# if self.faiss_retriever: -# current_retrievers.append(self.faiss_retriever) -# else: -# print("Warning: FaissDBRetriever requested but not available. Proceeding without it.") - -# rag_pipeline = DspyRagPipeline(retrievers=current_retrievers) - -# answer_text = await rag_pipeline.forward( -# question=prompt, -# history=session.messages, -# db=db -# ) - -# assistant_message = models.Message(session_id=session_id, sender="assistant", content=answer_text) -# db.add(assistant_message) -# db.commit() - -# return answer_text, model - -# def get_message_history(self, db: Session, session_id: int) -> List[models.Message]: -# """ -# Retrieves all messages for a given session, or None if the session doesn't exist. -# """ -# session = db.query(models.Session).options( -# joinedload(models.Session.messages) -# ).filter(models.Session.id == session_id).first() - -# return sorted(session.messages, key=lambda msg: msg.created_at) if session else None - -# # --- Document Management (Updated) --- -# def add_document(self, db: Session, doc_data: Dict[str, Any]) -> int: -# try: -# document_db = models.Document(**doc_data) -# db.add(document_db) -# db.commit() -# db.refresh(document_db) - -# # Use the embedder provided to the vector store to get the correct model name -# embedding_model_name = "mock_embedder" if isinstance(self.embedder, MockEmbedder) else "GenAIEmbedder" - -# faiss_index = self.vector_store.add_document(document_db.text) -# vector_metadata = models.VectorMetadata( -# document_id=document_db.id, -# faiss_index=faiss_index, -# embedding_model=embedding_model_name -# ) -# db.add(vector_metadata) -# db.commit() -# return document_db.id -# except SQLAlchemyError as e: -# db.rollback() -# raise - -# def get_all_documents(self, db: Session) -> List[models.Document]: -# return db.query(models.Document).order_by(models.Document.created_at.desc()).all() - -# def delete_document(self, db: Session, document_id: int) -> int: -# try: -# doc_to_delete = db.query(models.Document).filter(models.Document.id == document_id).first() -# if not doc_to_delete: -# return None -# db.delete(doc_to_delete) -# db.commit() -# return document_id -# except SQLAlchemyError as e: -# db.rollback() -# raise diff --git a/ai-hub/app/core/services/rag.py b/ai-hub/app/core/services/rag.py index 344e779..517278d 100644 --- a/ai-hub/app/core/services/rag.py +++ b/ai-hub/app/core/services/rag.py @@ -6,7 +6,8 @@ from app.core.vector_store.faiss_store import FaissVectorStore from app.db import models -from app.core.retrievers import Retriever, FaissDBRetriever +from app.core.retrievers.faiss_db_retriever import FaissDBRetriever +from app.core.retrievers.base_retriever import Retriever from app.core.llm_providers import get_llm_provider from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline diff --git a/ai-hub/app/api/dependencies.py b/ai-hub/app/api/dependencies.py index 18afe8f..e1f2be7 100644 --- a/ai-hub/app/api/dependencies.py +++ b/ai-hub/app/api/dependencies.py @@ -3,7 +3,7 @@ from typing import List from sqlalchemy.orm import Session from app.db.session import SessionLocal -from app.core.retrievers import Retriever +from app.core.retrievers.base_retriever import Retriever from app.core.services.document import DocumentService from app.core.services.rag import RAGService from app.core.vector_store.faiss_store import FaissVectorStore diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index 94413c8..a747ae6 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -6,7 +6,8 @@ from app.config import settings from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.factory import get_embedder_from_config -from app.core.retrievers import FaissDBRetriever, Retriever +from app.core.retrievers.faiss_db_retriever import FaissDBRetriever +from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables from app.api.routes import create_api_router from app.utils import print_config diff --git a/ai-hub/app/core/pipelines/dspy_rag.py b/ai-hub/app/core/pipelines/dspy_rag.py index d5e44de..88d1915 100644 --- a/ai-hub/app/core/pipelines/dspy_rag.py +++ b/ai-hub/app/core/pipelines/dspy_rag.py @@ -5,7 +5,7 @@ from sqlalchemy.orm import Session from app.db import models -from app.core.retrievers import Retriever +from app.core.retrievers.base_retriever import Retriever from app.core.llm_providers import LLMProvider diff --git a/ai-hub/app/core/retrievers.py b/ai-hub/app/core/retrievers.py deleted file mode 100644 index d2c158b..0000000 --- a/ai-hub/app/core/retrievers.py +++ /dev/null @@ -1,65 +0,0 @@ -import abc -from typing import List, Dict -from sqlalchemy.orm import Session -from app.core.vector_store.faiss_store import FaissVectorStore -from app.db import models - -class Retriever(abc.ABC): - """ - Abstract base class for a Retriever. - - A retriever is a pluggable component that is responsible for fetching - relevant context for a given query from a specific data source. - """ - @abc.abstractmethod - def retrieve_context(self, query: str, db: Session) -> List[str]: - """ - Fetches context for a given query. - - Args: - query (str): The user's query string. - db (Session): The database session. - - Returns: - List[str]: A list of text strings representing the retrieved context. - """ - raise NotImplementedError - -class FaissDBRetriever(Retriever): - """ - A concrete retriever that uses a FAISS index and a local database - to find and return relevant document text. - """ - def __init__(self, vector_store: FaissVectorStore): - self.vector_store = vector_store - - def retrieve_context(self, query: str, db: Session) -> List[str]: - """ - Retrieves document text by first searching the FAISS index - and then fetching the corresponding documents from the database. - """ - faiss_ids = self.vector_store.search_similar_documents(query, k=3) - context_docs_text = [] - - if faiss_ids: - # Use FAISS IDs to find the corresponding document_id from the database - document_ids_from_vectors = db.query(models.VectorMetadata.document_id).filter( - models.VectorMetadata.faiss_index.in_(faiss_ids) - ).all() - - document_ids = [doc_id for (doc_id,) in document_ids_from_vectors] - - # Retrieve the full documents from the Document table - context_docs = db.query(models.Document).filter( - models.Document.id.in_(document_ids) - ).all() - - context_docs_text = [doc.text for doc in context_docs] - - return context_docs_text - -# You could add other retriever implementations here, like: -# class RemoteServiceRetriever(Retriever): -# def retrieve_context(self, query: str, db: Session) -> List[str]: -# # Logic to call a remote API and return context -# ... diff --git a/ai-hub/app/core/retrievers/__init__.py b/ai-hub/app/core/retrievers/__init__.py new file mode 100644 index 0000000..3fbb1fd --- /dev/null +++ b/ai-hub/app/core/retrievers/__init__.py @@ -0,0 +1 @@ +# This file can be left empty. diff --git a/ai-hub/app/core/retrievers/base_retriever.py b/ai-hub/app/core/retrievers/base_retriever.py new file mode 100644 index 0000000..b902132 --- /dev/null +++ b/ai-hub/app/core/retrievers/base_retriever.py @@ -0,0 +1,24 @@ +import abc +from typing import List +from sqlalchemy.orm import Session + +class Retriever(abc.ABC): + """ + Abstract base class for a Retriever. + + A retriever is a pluggable component that is responsible for fetching + relevant context for a given query from a specific data source. + """ + @abc.abstractmethod + def retrieve_context(self, query: str, db: Session) -> List[str]: + """ + Fetches context for a given query. + + Args: + query (str): The user's query string. + db (Session): The database session. + + Returns: + List[str]: A list of text strings representing the retrieved context. + """ + raise NotImplementedError \ No newline at end of file diff --git a/ai-hub/app/core/retrievers/faiss_db_retriever.py b/ai-hub/app/core/retrievers/faiss_db_retriever.py new file mode 100644 index 0000000..3c3c202 --- /dev/null +++ b/ai-hub/app/core/retrievers/faiss_db_retriever.py @@ -0,0 +1,38 @@ +from typing import List +from sqlalchemy.orm import Session +from app.core.vector_store.faiss_store import FaissVectorStore +from app.db import models +from app.core.retrievers.base_retriever import Retriever + +class FaissDBRetriever(Retriever): + """ + A concrete retriever that uses a FAISS index and a local database + to find and return relevant document text. + """ + def __init__(self, vector_store: FaissVectorStore): + self.vector_store = vector_store + + def retrieve_context(self, query: str, db: Session) -> List[str]: + """ + Retrieves document text by first searching the FAISS index + and then fetching the corresponding documents from the database. + """ + faiss_ids = self.vector_store.search_similar_documents(query, k=3) + context_docs_text = [] + + if faiss_ids: + # Use FAISS IDs to find the corresponding document_id from the database + document_ids_from_vectors = db.query(models.VectorMetadata.document_id).filter( + models.VectorMetadata.faiss_index.in_(faiss_ids) + ).all() + + document_ids = [doc_id for (doc_id,) in document_ids_from_vectors] + + # Retrieve the full documents from the Document table + context_docs = db.query(models.Document).filter( + models.Document.id.in_(document_ids) + ).all() + + context_docs_text = [doc.text for doc in context_docs] + + return context_docs_text \ No newline at end of file diff --git a/ai-hub/app/core/services.py b/ai-hub/app/core/services.py deleted file mode 100644 index a5694dc..0000000 --- a/ai-hub/app/core/services.py +++ /dev/null @@ -1,141 +0,0 @@ -# import asyncio -# from typing import List, Dict, Any, Tuple -# from sqlalchemy.orm import Session, joinedload -# from sqlalchemy.exc import SQLAlchemyError -# import dspy - -# from app.core.vector_store.faiss_store import FaissVectorStore -# from app.core.vector_store.embedder.mock import MockEmbedder -# from app.db import models -# from app.core.retrievers import Retriever, FaissDBRetriever -# from app.core.llm_providers import get_llm_provider -# from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline - -# class RAGService: -# """ -# Service class for managing documents and conversational RAG sessions. -# This class is now more robust and can handle both real and mock embedders -# by inspecting its dependencies. -# """ -# def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): -# self.vector_store = vector_store -# self.retrievers = retrievers - -# # Assume one of the retrievers is the FAISS retriever, and you can access it. -# # A better approach might be to have a dictionary of named retrievers. -# self.faiss_retriever = next((r for r in retrievers if isinstance(r, FaissDBRetriever)), None) - -# # Store the embedder from the vector store for dynamic naming -# self.embedder = self.vector_store.embedder - - -# # --- Session Management --- - -# def create_session(self, db: Session, user_id: str, model: str) -> models.Session: -# """Creates a new chat session in the database.""" -# try: -# new_session = models.Session(user_id=user_id, model_name=model, title=f"New Chat Session") -# db.add(new_session) -# db.commit() -# db.refresh(new_session) -# return new_session -# except SQLAlchemyError as e: -# db.rollback() -# raise - -# async def chat_with_rag( -# self, -# db: Session, -# session_id: int, -# prompt: str, -# model: str, -# load_faiss_retriever: bool = False -# ) -> Tuple[str, str]: -# """ -# Handles a message within a session, including saving history and getting a response. -# Allows switching the LLM model and conditionally using the FAISS retriever. -# """ -# session = db.query(models.Session).options( -# joinedload(models.Session.messages) -# ).filter(models.Session.id == session_id).first() - -# if not session: -# raise ValueError(f"Session with ID {session_id} not found.") - -# user_message = models.Message(session_id=session_id, sender="user", content=prompt) -# db.add(user_message) -# db.commit() - -# llm_provider = get_llm_provider(model) -# dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=model) -# dspy.configure(lm=dspy_llm) - -# current_retrievers = [] -# if load_faiss_retriever: -# if self.faiss_retriever: -# current_retrievers.append(self.faiss_retriever) -# else: -# print("Warning: FaissDBRetriever requested but not available. Proceeding without it.") - -# rag_pipeline = DspyRagPipeline(retrievers=current_retrievers) - -# answer_text = await rag_pipeline.forward( -# question=prompt, -# history=session.messages, -# db=db -# ) - -# assistant_message = models.Message(session_id=session_id, sender="assistant", content=answer_text) -# db.add(assistant_message) -# db.commit() - -# return answer_text, model - -# def get_message_history(self, db: Session, session_id: int) -> List[models.Message]: -# """ -# Retrieves all messages for a given session, or None if the session doesn't exist. -# """ -# session = db.query(models.Session).options( -# joinedload(models.Session.messages) -# ).filter(models.Session.id == session_id).first() - -# return sorted(session.messages, key=lambda msg: msg.created_at) if session else None - -# # --- Document Management (Updated) --- -# def add_document(self, db: Session, doc_data: Dict[str, Any]) -> int: -# try: -# document_db = models.Document(**doc_data) -# db.add(document_db) -# db.commit() -# db.refresh(document_db) - -# # Use the embedder provided to the vector store to get the correct model name -# embedding_model_name = "mock_embedder" if isinstance(self.embedder, MockEmbedder) else "GenAIEmbedder" - -# faiss_index = self.vector_store.add_document(document_db.text) -# vector_metadata = models.VectorMetadata( -# document_id=document_db.id, -# faiss_index=faiss_index, -# embedding_model=embedding_model_name -# ) -# db.add(vector_metadata) -# db.commit() -# return document_db.id -# except SQLAlchemyError as e: -# db.rollback() -# raise - -# def get_all_documents(self, db: Session) -> List[models.Document]: -# return db.query(models.Document).order_by(models.Document.created_at.desc()).all() - -# def delete_document(self, db: Session, document_id: int) -> int: -# try: -# doc_to_delete = db.query(models.Document).filter(models.Document.id == document_id).first() -# if not doc_to_delete: -# return None -# db.delete(doc_to_delete) -# db.commit() -# return document_id -# except SQLAlchemyError as e: -# db.rollback() -# raise diff --git a/ai-hub/app/core/services/rag.py b/ai-hub/app/core/services/rag.py index 344e779..517278d 100644 --- a/ai-hub/app/core/services/rag.py +++ b/ai-hub/app/core/services/rag.py @@ -6,7 +6,8 @@ from app.core.vector_store.faiss_store import FaissVectorStore from app.db import models -from app.core.retrievers import Retriever, FaissDBRetriever +from app.core.retrievers.faiss_db_retriever import FaissDBRetriever +from app.core.retrievers.base_retriever import Retriever from app.core.llm_providers import get_llm_provider from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline diff --git a/ai-hub/tests/api/test_dependencies.py b/ai-hub/tests/api/test_dependencies.py index 470780d..65490e8 100644 --- a/ai-hub/tests/api/test_dependencies.py +++ b/ai-hub/tests/api/test_dependencies.py @@ -10,7 +10,7 @@ from app.core.services.document import DocumentService from app.core.services.rag import RAGService from app.core.vector_store.faiss_store import FaissVectorStore -from app.core.retrievers import Retriever +from app.core.retrievers.base_retriever import Retriever @pytest.fixture def mock_session(): diff --git a/ai-hub/app/api/dependencies.py b/ai-hub/app/api/dependencies.py index 18afe8f..e1f2be7 100644 --- a/ai-hub/app/api/dependencies.py +++ b/ai-hub/app/api/dependencies.py @@ -3,7 +3,7 @@ from typing import List from sqlalchemy.orm import Session from app.db.session import SessionLocal -from app.core.retrievers import Retriever +from app.core.retrievers.base_retriever import Retriever from app.core.services.document import DocumentService from app.core.services.rag import RAGService from app.core.vector_store.faiss_store import FaissVectorStore diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index 94413c8..a747ae6 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -6,7 +6,8 @@ from app.config import settings from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.factory import get_embedder_from_config -from app.core.retrievers import FaissDBRetriever, Retriever +from app.core.retrievers.faiss_db_retriever import FaissDBRetriever +from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables from app.api.routes import create_api_router from app.utils import print_config diff --git a/ai-hub/app/core/pipelines/dspy_rag.py b/ai-hub/app/core/pipelines/dspy_rag.py index d5e44de..88d1915 100644 --- a/ai-hub/app/core/pipelines/dspy_rag.py +++ b/ai-hub/app/core/pipelines/dspy_rag.py @@ -5,7 +5,7 @@ from sqlalchemy.orm import Session from app.db import models -from app.core.retrievers import Retriever +from app.core.retrievers.base_retriever import Retriever from app.core.llm_providers import LLMProvider diff --git a/ai-hub/app/core/retrievers.py b/ai-hub/app/core/retrievers.py deleted file mode 100644 index d2c158b..0000000 --- a/ai-hub/app/core/retrievers.py +++ /dev/null @@ -1,65 +0,0 @@ -import abc -from typing import List, Dict -from sqlalchemy.orm import Session -from app.core.vector_store.faiss_store import FaissVectorStore -from app.db import models - -class Retriever(abc.ABC): - """ - Abstract base class for a Retriever. - - A retriever is a pluggable component that is responsible for fetching - relevant context for a given query from a specific data source. - """ - @abc.abstractmethod - def retrieve_context(self, query: str, db: Session) -> List[str]: - """ - Fetches context for a given query. - - Args: - query (str): The user's query string. - db (Session): The database session. - - Returns: - List[str]: A list of text strings representing the retrieved context. - """ - raise NotImplementedError - -class FaissDBRetriever(Retriever): - """ - A concrete retriever that uses a FAISS index and a local database - to find and return relevant document text. - """ - def __init__(self, vector_store: FaissVectorStore): - self.vector_store = vector_store - - def retrieve_context(self, query: str, db: Session) -> List[str]: - """ - Retrieves document text by first searching the FAISS index - and then fetching the corresponding documents from the database. - """ - faiss_ids = self.vector_store.search_similar_documents(query, k=3) - context_docs_text = [] - - if faiss_ids: - # Use FAISS IDs to find the corresponding document_id from the database - document_ids_from_vectors = db.query(models.VectorMetadata.document_id).filter( - models.VectorMetadata.faiss_index.in_(faiss_ids) - ).all() - - document_ids = [doc_id for (doc_id,) in document_ids_from_vectors] - - # Retrieve the full documents from the Document table - context_docs = db.query(models.Document).filter( - models.Document.id.in_(document_ids) - ).all() - - context_docs_text = [doc.text for doc in context_docs] - - return context_docs_text - -# You could add other retriever implementations here, like: -# class RemoteServiceRetriever(Retriever): -# def retrieve_context(self, query: str, db: Session) -> List[str]: -# # Logic to call a remote API and return context -# ... diff --git a/ai-hub/app/core/retrievers/__init__.py b/ai-hub/app/core/retrievers/__init__.py new file mode 100644 index 0000000..3fbb1fd --- /dev/null +++ b/ai-hub/app/core/retrievers/__init__.py @@ -0,0 +1 @@ +# This file can be left empty. diff --git a/ai-hub/app/core/retrievers/base_retriever.py b/ai-hub/app/core/retrievers/base_retriever.py new file mode 100644 index 0000000..b902132 --- /dev/null +++ b/ai-hub/app/core/retrievers/base_retriever.py @@ -0,0 +1,24 @@ +import abc +from typing import List +from sqlalchemy.orm import Session + +class Retriever(abc.ABC): + """ + Abstract base class for a Retriever. + + A retriever is a pluggable component that is responsible for fetching + relevant context for a given query from a specific data source. + """ + @abc.abstractmethod + def retrieve_context(self, query: str, db: Session) -> List[str]: + """ + Fetches context for a given query. + + Args: + query (str): The user's query string. + db (Session): The database session. + + Returns: + List[str]: A list of text strings representing the retrieved context. + """ + raise NotImplementedError \ No newline at end of file diff --git a/ai-hub/app/core/retrievers/faiss_db_retriever.py b/ai-hub/app/core/retrievers/faiss_db_retriever.py new file mode 100644 index 0000000..3c3c202 --- /dev/null +++ b/ai-hub/app/core/retrievers/faiss_db_retriever.py @@ -0,0 +1,38 @@ +from typing import List +from sqlalchemy.orm import Session +from app.core.vector_store.faiss_store import FaissVectorStore +from app.db import models +from app.core.retrievers.base_retriever import Retriever + +class FaissDBRetriever(Retriever): + """ + A concrete retriever that uses a FAISS index and a local database + to find and return relevant document text. + """ + def __init__(self, vector_store: FaissVectorStore): + self.vector_store = vector_store + + def retrieve_context(self, query: str, db: Session) -> List[str]: + """ + Retrieves document text by first searching the FAISS index + and then fetching the corresponding documents from the database. + """ + faiss_ids = self.vector_store.search_similar_documents(query, k=3) + context_docs_text = [] + + if faiss_ids: + # Use FAISS IDs to find the corresponding document_id from the database + document_ids_from_vectors = db.query(models.VectorMetadata.document_id).filter( + models.VectorMetadata.faiss_index.in_(faiss_ids) + ).all() + + document_ids = [doc_id for (doc_id,) in document_ids_from_vectors] + + # Retrieve the full documents from the Document table + context_docs = db.query(models.Document).filter( + models.Document.id.in_(document_ids) + ).all() + + context_docs_text = [doc.text for doc in context_docs] + + return context_docs_text \ No newline at end of file diff --git a/ai-hub/app/core/services.py b/ai-hub/app/core/services.py deleted file mode 100644 index a5694dc..0000000 --- a/ai-hub/app/core/services.py +++ /dev/null @@ -1,141 +0,0 @@ -# import asyncio -# from typing import List, Dict, Any, Tuple -# from sqlalchemy.orm import Session, joinedload -# from sqlalchemy.exc import SQLAlchemyError -# import dspy - -# from app.core.vector_store.faiss_store import FaissVectorStore -# from app.core.vector_store.embedder.mock import MockEmbedder -# from app.db import models -# from app.core.retrievers import Retriever, FaissDBRetriever -# from app.core.llm_providers import get_llm_provider -# from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline - -# class RAGService: -# """ -# Service class for managing documents and conversational RAG sessions. -# This class is now more robust and can handle both real and mock embedders -# by inspecting its dependencies. -# """ -# def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): -# self.vector_store = vector_store -# self.retrievers = retrievers - -# # Assume one of the retrievers is the FAISS retriever, and you can access it. -# # A better approach might be to have a dictionary of named retrievers. -# self.faiss_retriever = next((r for r in retrievers if isinstance(r, FaissDBRetriever)), None) - -# # Store the embedder from the vector store for dynamic naming -# self.embedder = self.vector_store.embedder - - -# # --- Session Management --- - -# def create_session(self, db: Session, user_id: str, model: str) -> models.Session: -# """Creates a new chat session in the database.""" -# try: -# new_session = models.Session(user_id=user_id, model_name=model, title=f"New Chat Session") -# db.add(new_session) -# db.commit() -# db.refresh(new_session) -# return new_session -# except SQLAlchemyError as e: -# db.rollback() -# raise - -# async def chat_with_rag( -# self, -# db: Session, -# session_id: int, -# prompt: str, -# model: str, -# load_faiss_retriever: bool = False -# ) -> Tuple[str, str]: -# """ -# Handles a message within a session, including saving history and getting a response. -# Allows switching the LLM model and conditionally using the FAISS retriever. -# """ -# session = db.query(models.Session).options( -# joinedload(models.Session.messages) -# ).filter(models.Session.id == session_id).first() - -# if not session: -# raise ValueError(f"Session with ID {session_id} not found.") - -# user_message = models.Message(session_id=session_id, sender="user", content=prompt) -# db.add(user_message) -# db.commit() - -# llm_provider = get_llm_provider(model) -# dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=model) -# dspy.configure(lm=dspy_llm) - -# current_retrievers = [] -# if load_faiss_retriever: -# if self.faiss_retriever: -# current_retrievers.append(self.faiss_retriever) -# else: -# print("Warning: FaissDBRetriever requested but not available. Proceeding without it.") - -# rag_pipeline = DspyRagPipeline(retrievers=current_retrievers) - -# answer_text = await rag_pipeline.forward( -# question=prompt, -# history=session.messages, -# db=db -# ) - -# assistant_message = models.Message(session_id=session_id, sender="assistant", content=answer_text) -# db.add(assistant_message) -# db.commit() - -# return answer_text, model - -# def get_message_history(self, db: Session, session_id: int) -> List[models.Message]: -# """ -# Retrieves all messages for a given session, or None if the session doesn't exist. -# """ -# session = db.query(models.Session).options( -# joinedload(models.Session.messages) -# ).filter(models.Session.id == session_id).first() - -# return sorted(session.messages, key=lambda msg: msg.created_at) if session else None - -# # --- Document Management (Updated) --- -# def add_document(self, db: Session, doc_data: Dict[str, Any]) -> int: -# try: -# document_db = models.Document(**doc_data) -# db.add(document_db) -# db.commit() -# db.refresh(document_db) - -# # Use the embedder provided to the vector store to get the correct model name -# embedding_model_name = "mock_embedder" if isinstance(self.embedder, MockEmbedder) else "GenAIEmbedder" - -# faiss_index = self.vector_store.add_document(document_db.text) -# vector_metadata = models.VectorMetadata( -# document_id=document_db.id, -# faiss_index=faiss_index, -# embedding_model=embedding_model_name -# ) -# db.add(vector_metadata) -# db.commit() -# return document_db.id -# except SQLAlchemyError as e: -# db.rollback() -# raise - -# def get_all_documents(self, db: Session) -> List[models.Document]: -# return db.query(models.Document).order_by(models.Document.created_at.desc()).all() - -# def delete_document(self, db: Session, document_id: int) -> int: -# try: -# doc_to_delete = db.query(models.Document).filter(models.Document.id == document_id).first() -# if not doc_to_delete: -# return None -# db.delete(doc_to_delete) -# db.commit() -# return document_id -# except SQLAlchemyError as e: -# db.rollback() -# raise diff --git a/ai-hub/app/core/services/rag.py b/ai-hub/app/core/services/rag.py index 344e779..517278d 100644 --- a/ai-hub/app/core/services/rag.py +++ b/ai-hub/app/core/services/rag.py @@ -6,7 +6,8 @@ from app.core.vector_store.faiss_store import FaissVectorStore from app.db import models -from app.core.retrievers import Retriever, FaissDBRetriever +from app.core.retrievers.faiss_db_retriever import FaissDBRetriever +from app.core.retrievers.base_retriever import Retriever from app.core.llm_providers import get_llm_provider from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline diff --git a/ai-hub/tests/api/test_dependencies.py b/ai-hub/tests/api/test_dependencies.py index 470780d..65490e8 100644 --- a/ai-hub/tests/api/test_dependencies.py +++ b/ai-hub/tests/api/test_dependencies.py @@ -10,7 +10,7 @@ from app.core.services.document import DocumentService from app.core.services.rag import RAGService from app.core.vector_store.faiss_store import FaissVectorStore -from app.core.retrievers import Retriever +from app.core.retrievers.base_retriever import Retriever @pytest.fixture def mock_session(): diff --git a/ai-hub/tests/core/pipelines/test_dspy_rag.py b/ai-hub/tests/core/pipelines/test_dspy_rag.py index d1b8c62..4e44c4b 100644 --- a/ai-hub/tests/core/pipelines/test_dspy_rag.py +++ b/ai-hub/tests/core/pipelines/test_dspy_rag.py @@ -7,7 +7,7 @@ # Import the pipeline and its new signature from app.core.pipelines.dspy_rag import DspyRagPipeline, AnswerWithHistory from app.db import models # Import your SQLAlchemy models for mocking history -from app.core.retrievers import Retriever +from app.core.retrievers.base_retriever import Retriever @pytest.fixture def mock_lm_configured(): diff --git a/ai-hub/app/api/dependencies.py b/ai-hub/app/api/dependencies.py index 18afe8f..e1f2be7 100644 --- a/ai-hub/app/api/dependencies.py +++ b/ai-hub/app/api/dependencies.py @@ -3,7 +3,7 @@ from typing import List from sqlalchemy.orm import Session from app.db.session import SessionLocal -from app.core.retrievers import Retriever +from app.core.retrievers.base_retriever import Retriever from app.core.services.document import DocumentService from app.core.services.rag import RAGService from app.core.vector_store.faiss_store import FaissVectorStore diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index 94413c8..a747ae6 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -6,7 +6,8 @@ from app.config import settings from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.factory import get_embedder_from_config -from app.core.retrievers import FaissDBRetriever, Retriever +from app.core.retrievers.faiss_db_retriever import FaissDBRetriever +from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables from app.api.routes import create_api_router from app.utils import print_config diff --git a/ai-hub/app/core/pipelines/dspy_rag.py b/ai-hub/app/core/pipelines/dspy_rag.py index d5e44de..88d1915 100644 --- a/ai-hub/app/core/pipelines/dspy_rag.py +++ b/ai-hub/app/core/pipelines/dspy_rag.py @@ -5,7 +5,7 @@ from sqlalchemy.orm import Session from app.db import models -from app.core.retrievers import Retriever +from app.core.retrievers.base_retriever import Retriever from app.core.llm_providers import LLMProvider diff --git a/ai-hub/app/core/retrievers.py b/ai-hub/app/core/retrievers.py deleted file mode 100644 index d2c158b..0000000 --- a/ai-hub/app/core/retrievers.py +++ /dev/null @@ -1,65 +0,0 @@ -import abc -from typing import List, Dict -from sqlalchemy.orm import Session -from app.core.vector_store.faiss_store import FaissVectorStore -from app.db import models - -class Retriever(abc.ABC): - """ - Abstract base class for a Retriever. - - A retriever is a pluggable component that is responsible for fetching - relevant context for a given query from a specific data source. - """ - @abc.abstractmethod - def retrieve_context(self, query: str, db: Session) -> List[str]: - """ - Fetches context for a given query. - - Args: - query (str): The user's query string. - db (Session): The database session. - - Returns: - List[str]: A list of text strings representing the retrieved context. - """ - raise NotImplementedError - -class FaissDBRetriever(Retriever): - """ - A concrete retriever that uses a FAISS index and a local database - to find and return relevant document text. - """ - def __init__(self, vector_store: FaissVectorStore): - self.vector_store = vector_store - - def retrieve_context(self, query: str, db: Session) -> List[str]: - """ - Retrieves document text by first searching the FAISS index - and then fetching the corresponding documents from the database. - """ - faiss_ids = self.vector_store.search_similar_documents(query, k=3) - context_docs_text = [] - - if faiss_ids: - # Use FAISS IDs to find the corresponding document_id from the database - document_ids_from_vectors = db.query(models.VectorMetadata.document_id).filter( - models.VectorMetadata.faiss_index.in_(faiss_ids) - ).all() - - document_ids = [doc_id for (doc_id,) in document_ids_from_vectors] - - # Retrieve the full documents from the Document table - context_docs = db.query(models.Document).filter( - models.Document.id.in_(document_ids) - ).all() - - context_docs_text = [doc.text for doc in context_docs] - - return context_docs_text - -# You could add other retriever implementations here, like: -# class RemoteServiceRetriever(Retriever): -# def retrieve_context(self, query: str, db: Session) -> List[str]: -# # Logic to call a remote API and return context -# ... diff --git a/ai-hub/app/core/retrievers/__init__.py b/ai-hub/app/core/retrievers/__init__.py new file mode 100644 index 0000000..3fbb1fd --- /dev/null +++ b/ai-hub/app/core/retrievers/__init__.py @@ -0,0 +1 @@ +# This file can be left empty. diff --git a/ai-hub/app/core/retrievers/base_retriever.py b/ai-hub/app/core/retrievers/base_retriever.py new file mode 100644 index 0000000..b902132 --- /dev/null +++ b/ai-hub/app/core/retrievers/base_retriever.py @@ -0,0 +1,24 @@ +import abc +from typing import List +from sqlalchemy.orm import Session + +class Retriever(abc.ABC): + """ + Abstract base class for a Retriever. + + A retriever is a pluggable component that is responsible for fetching + relevant context for a given query from a specific data source. + """ + @abc.abstractmethod + def retrieve_context(self, query: str, db: Session) -> List[str]: + """ + Fetches context for a given query. + + Args: + query (str): The user's query string. + db (Session): The database session. + + Returns: + List[str]: A list of text strings representing the retrieved context. + """ + raise NotImplementedError \ No newline at end of file diff --git a/ai-hub/app/core/retrievers/faiss_db_retriever.py b/ai-hub/app/core/retrievers/faiss_db_retriever.py new file mode 100644 index 0000000..3c3c202 --- /dev/null +++ b/ai-hub/app/core/retrievers/faiss_db_retriever.py @@ -0,0 +1,38 @@ +from typing import List +from sqlalchemy.orm import Session +from app.core.vector_store.faiss_store import FaissVectorStore +from app.db import models +from app.core.retrievers.base_retriever import Retriever + +class FaissDBRetriever(Retriever): + """ + A concrete retriever that uses a FAISS index and a local database + to find and return relevant document text. + """ + def __init__(self, vector_store: FaissVectorStore): + self.vector_store = vector_store + + def retrieve_context(self, query: str, db: Session) -> List[str]: + """ + Retrieves document text by first searching the FAISS index + and then fetching the corresponding documents from the database. + """ + faiss_ids = self.vector_store.search_similar_documents(query, k=3) + context_docs_text = [] + + if faiss_ids: + # Use FAISS IDs to find the corresponding document_id from the database + document_ids_from_vectors = db.query(models.VectorMetadata.document_id).filter( + models.VectorMetadata.faiss_index.in_(faiss_ids) + ).all() + + document_ids = [doc_id for (doc_id,) in document_ids_from_vectors] + + # Retrieve the full documents from the Document table + context_docs = db.query(models.Document).filter( + models.Document.id.in_(document_ids) + ).all() + + context_docs_text = [doc.text for doc in context_docs] + + return context_docs_text \ No newline at end of file diff --git a/ai-hub/app/core/services.py b/ai-hub/app/core/services.py deleted file mode 100644 index a5694dc..0000000 --- a/ai-hub/app/core/services.py +++ /dev/null @@ -1,141 +0,0 @@ -# import asyncio -# from typing import List, Dict, Any, Tuple -# from sqlalchemy.orm import Session, joinedload -# from sqlalchemy.exc import SQLAlchemyError -# import dspy - -# from app.core.vector_store.faiss_store import FaissVectorStore -# from app.core.vector_store.embedder.mock import MockEmbedder -# from app.db import models -# from app.core.retrievers import Retriever, FaissDBRetriever -# from app.core.llm_providers import get_llm_provider -# from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline - -# class RAGService: -# """ -# Service class for managing documents and conversational RAG sessions. -# This class is now more robust and can handle both real and mock embedders -# by inspecting its dependencies. -# """ -# def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): -# self.vector_store = vector_store -# self.retrievers = retrievers - -# # Assume one of the retrievers is the FAISS retriever, and you can access it. -# # A better approach might be to have a dictionary of named retrievers. -# self.faiss_retriever = next((r for r in retrievers if isinstance(r, FaissDBRetriever)), None) - -# # Store the embedder from the vector store for dynamic naming -# self.embedder = self.vector_store.embedder - - -# # --- Session Management --- - -# def create_session(self, db: Session, user_id: str, model: str) -> models.Session: -# """Creates a new chat session in the database.""" -# try: -# new_session = models.Session(user_id=user_id, model_name=model, title=f"New Chat Session") -# db.add(new_session) -# db.commit() -# db.refresh(new_session) -# return new_session -# except SQLAlchemyError as e: -# db.rollback() -# raise - -# async def chat_with_rag( -# self, -# db: Session, -# session_id: int, -# prompt: str, -# model: str, -# load_faiss_retriever: bool = False -# ) -> Tuple[str, str]: -# """ -# Handles a message within a session, including saving history and getting a response. -# Allows switching the LLM model and conditionally using the FAISS retriever. -# """ -# session = db.query(models.Session).options( -# joinedload(models.Session.messages) -# ).filter(models.Session.id == session_id).first() - -# if not session: -# raise ValueError(f"Session with ID {session_id} not found.") - -# user_message = models.Message(session_id=session_id, sender="user", content=prompt) -# db.add(user_message) -# db.commit() - -# llm_provider = get_llm_provider(model) -# dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=model) -# dspy.configure(lm=dspy_llm) - -# current_retrievers = [] -# if load_faiss_retriever: -# if self.faiss_retriever: -# current_retrievers.append(self.faiss_retriever) -# else: -# print("Warning: FaissDBRetriever requested but not available. Proceeding without it.") - -# rag_pipeline = DspyRagPipeline(retrievers=current_retrievers) - -# answer_text = await rag_pipeline.forward( -# question=prompt, -# history=session.messages, -# db=db -# ) - -# assistant_message = models.Message(session_id=session_id, sender="assistant", content=answer_text) -# db.add(assistant_message) -# db.commit() - -# return answer_text, model - -# def get_message_history(self, db: Session, session_id: int) -> List[models.Message]: -# """ -# Retrieves all messages for a given session, or None if the session doesn't exist. -# """ -# session = db.query(models.Session).options( -# joinedload(models.Session.messages) -# ).filter(models.Session.id == session_id).first() - -# return sorted(session.messages, key=lambda msg: msg.created_at) if session else None - -# # --- Document Management (Updated) --- -# def add_document(self, db: Session, doc_data: Dict[str, Any]) -> int: -# try: -# document_db = models.Document(**doc_data) -# db.add(document_db) -# db.commit() -# db.refresh(document_db) - -# # Use the embedder provided to the vector store to get the correct model name -# embedding_model_name = "mock_embedder" if isinstance(self.embedder, MockEmbedder) else "GenAIEmbedder" - -# faiss_index = self.vector_store.add_document(document_db.text) -# vector_metadata = models.VectorMetadata( -# document_id=document_db.id, -# faiss_index=faiss_index, -# embedding_model=embedding_model_name -# ) -# db.add(vector_metadata) -# db.commit() -# return document_db.id -# except SQLAlchemyError as e: -# db.rollback() -# raise - -# def get_all_documents(self, db: Session) -> List[models.Document]: -# return db.query(models.Document).order_by(models.Document.created_at.desc()).all() - -# def delete_document(self, db: Session, document_id: int) -> int: -# try: -# doc_to_delete = db.query(models.Document).filter(models.Document.id == document_id).first() -# if not doc_to_delete: -# return None -# db.delete(doc_to_delete) -# db.commit() -# return document_id -# except SQLAlchemyError as e: -# db.rollback() -# raise diff --git a/ai-hub/app/core/services/rag.py b/ai-hub/app/core/services/rag.py index 344e779..517278d 100644 --- a/ai-hub/app/core/services/rag.py +++ b/ai-hub/app/core/services/rag.py @@ -6,7 +6,8 @@ from app.core.vector_store.faiss_store import FaissVectorStore from app.db import models -from app.core.retrievers import Retriever, FaissDBRetriever +from app.core.retrievers.faiss_db_retriever import FaissDBRetriever +from app.core.retrievers.base_retriever import Retriever from app.core.llm_providers import get_llm_provider from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline diff --git a/ai-hub/tests/api/test_dependencies.py b/ai-hub/tests/api/test_dependencies.py index 470780d..65490e8 100644 --- a/ai-hub/tests/api/test_dependencies.py +++ b/ai-hub/tests/api/test_dependencies.py @@ -10,7 +10,7 @@ from app.core.services.document import DocumentService from app.core.services.rag import RAGService from app.core.vector_store.faiss_store import FaissVectorStore -from app.core.retrievers import Retriever +from app.core.retrievers.base_retriever import Retriever @pytest.fixture def mock_session(): diff --git a/ai-hub/tests/core/pipelines/test_dspy_rag.py b/ai-hub/tests/core/pipelines/test_dspy_rag.py index d1b8c62..4e44c4b 100644 --- a/ai-hub/tests/core/pipelines/test_dspy_rag.py +++ b/ai-hub/tests/core/pipelines/test_dspy_rag.py @@ -7,7 +7,7 @@ # Import the pipeline and its new signature from app.core.pipelines.dspy_rag import DspyRagPipeline, AnswerWithHistory from app.db import models # Import your SQLAlchemy models for mocking history -from app.core.retrievers import Retriever +from app.core.retrievers.base_retriever import Retriever @pytest.fixture def mock_lm_configured(): diff --git a/ai-hub/tests/core/retrievers/test_faiss_db_retriever.py b/ai-hub/tests/core/retrievers/test_faiss_db_retriever.py new file mode 100644 index 0000000..695d5ae --- /dev/null +++ b/ai-hub/tests/core/retrievers/test_faiss_db_retriever.py @@ -0,0 +1,127 @@ +import pytest +from sqlalchemy import create_engine, Column, Integer, String, ForeignKey +from sqlalchemy.orm import sessionmaker, declarative_base, relationship +from typing import List + +# Mock the required models and FaissVectorStore for testing purposes +Base = declarative_base() + +class Document(Base): + __tablename__ = "documents" + id = Column(Integer, primary_key=True, index=True) + title = Column(String) # Add the missing columns + text = Column(String) + source_url = Column(String) + author = Column(String) + status = Column(String) + created_at = Column(String) + user_id = Column(Integer) + vectors = relationship("VectorMetadata", back_populates="document") + +class VectorMetadata(Base): + __tablename__ = "vector_metadata" + id = Column(Integer, primary_key=True, index=True) + faiss_index = Column(Integer, unique=True) + document_id = Column(Integer, ForeignKey("documents.id")) + document = relationship("Document", back_populates="vectors") + +# A mock version of the FaissVectorStore for testing +class FaissVectorStore: + def __init__(self, encoder): + self._index = {} + self._next_id = 0 + + def add_document_to_index(self, text: str) -> int: + self._index[self._next_id] = text + current_id = self._next_id + self._next_id += 1 + return current_id + + def search_similar_documents(self, query: str, k: int = 1) -> List[int]: + # This is a mock; in a real scenario, this would perform a vector search. + # Here, we'll assume it returns the IDs we expect for the test. + # This method is often patched in the tests. + return [0] + +# --- E2E test setup and fixtures --- +SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db" +engine = create_engine( + SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False} +) +TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +@pytest.fixture(scope="function") +def db_session(): + """Provides a clean database session for each test.""" + Base.metadata.create_all(bind=engine) + db = TestingSessionLocal() + try: + yield db + finally: + db.close() + Base.metadata.drop_all(bind=engine) + +@pytest.fixture(scope="function") +def faiss_store(): + """Provides a fresh FaissVectorStore instance for each test.""" + class MockEncoder: + def encode(self, text): + return [1.0] * 768 + + return FaissVectorStore(MockEncoder()) + +# --- E2E test cases --- +# Assuming FaissDBRetriever and its dependencies are correctly imported +# You need to make sure the import path is correct for your project structure +from app.core.retrievers.faiss_db_retriever import FaissDBRetriever + +def test_retrieve_context_successful(db_session, faiss_store): + """ + Tests that the retriever successfully finds and returns the correct + document text based on a FAISS search. + """ + # 1. Setup - Create documents and vectors + doc1 = Document( + title="Sample Title", # Provide values for the new columns + text="The quick brown fox jumps over the lazy dog." + ) + db_session.add(doc1) + db_session.commit() + db_session.refresh(doc1) + + # Mock the FAISS store to return a predictable ID + faiss_id1 = 123 + faiss_store.add_document_to_index = lambda text: faiss_id1 + + vec_meta1 = VectorMetadata(document_id=doc1.id, faiss_index=faiss_id1) + db_session.add(vec_meta1) + db_session.commit() + + # 2. Execution - Create and run the retriever + retriever = FaissDBRetriever(vector_store=faiss_store) + + # We'll mock the search to return the ID of our specific document + faiss_store.search_similar_documents = lambda query, k: [faiss_id1] + retrieved_context = retriever.retrieve_context(query="query for fox", db=db_session) + + # 3. Assertion - Verify the result + assert len(retrieved_context) == 1 + assert retrieved_context[0] == "The quick brown fox jumps over the lazy dog." + +def test_retrieve_context_no_match(db_session, faiss_store): + """ + Tests that the retriever returns an empty list when no matching + documents are found in the FAISS index. + """ + # 1. Setup - No documents or vectors are added to the database. + + # 2. Execution - Create and run the retriever + retriever = FaissDBRetriever(vector_store=faiss_store) + + # Mock the search to return an empty list + faiss_store.search_similar_documents = lambda query, k: [] + + retrieved_context = retriever.retrieve_context(query="non-existent query", db=db_session) + + # 3. Assertion - Verify the result is an empty list + assert retrieved_context == [] \ No newline at end of file diff --git a/ai-hub/app/api/dependencies.py b/ai-hub/app/api/dependencies.py index 18afe8f..e1f2be7 100644 --- a/ai-hub/app/api/dependencies.py +++ b/ai-hub/app/api/dependencies.py @@ -3,7 +3,7 @@ from typing import List from sqlalchemy.orm import Session from app.db.session import SessionLocal -from app.core.retrievers import Retriever +from app.core.retrievers.base_retriever import Retriever from app.core.services.document import DocumentService from app.core.services.rag import RAGService from app.core.vector_store.faiss_store import FaissVectorStore diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index 94413c8..a747ae6 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -6,7 +6,8 @@ from app.config import settings from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.factory import get_embedder_from_config -from app.core.retrievers import FaissDBRetriever, Retriever +from app.core.retrievers.faiss_db_retriever import FaissDBRetriever +from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables from app.api.routes import create_api_router from app.utils import print_config diff --git a/ai-hub/app/core/pipelines/dspy_rag.py b/ai-hub/app/core/pipelines/dspy_rag.py index d5e44de..88d1915 100644 --- a/ai-hub/app/core/pipelines/dspy_rag.py +++ b/ai-hub/app/core/pipelines/dspy_rag.py @@ -5,7 +5,7 @@ from sqlalchemy.orm import Session from app.db import models -from app.core.retrievers import Retriever +from app.core.retrievers.base_retriever import Retriever from app.core.llm_providers import LLMProvider diff --git a/ai-hub/app/core/retrievers.py b/ai-hub/app/core/retrievers.py deleted file mode 100644 index d2c158b..0000000 --- a/ai-hub/app/core/retrievers.py +++ /dev/null @@ -1,65 +0,0 @@ -import abc -from typing import List, Dict -from sqlalchemy.orm import Session -from app.core.vector_store.faiss_store import FaissVectorStore -from app.db import models - -class Retriever(abc.ABC): - """ - Abstract base class for a Retriever. - - A retriever is a pluggable component that is responsible for fetching - relevant context for a given query from a specific data source. - """ - @abc.abstractmethod - def retrieve_context(self, query: str, db: Session) -> List[str]: - """ - Fetches context for a given query. - - Args: - query (str): The user's query string. - db (Session): The database session. - - Returns: - List[str]: A list of text strings representing the retrieved context. - """ - raise NotImplementedError - -class FaissDBRetriever(Retriever): - """ - A concrete retriever that uses a FAISS index and a local database - to find and return relevant document text. - """ - def __init__(self, vector_store: FaissVectorStore): - self.vector_store = vector_store - - def retrieve_context(self, query: str, db: Session) -> List[str]: - """ - Retrieves document text by first searching the FAISS index - and then fetching the corresponding documents from the database. - """ - faiss_ids = self.vector_store.search_similar_documents(query, k=3) - context_docs_text = [] - - if faiss_ids: - # Use FAISS IDs to find the corresponding document_id from the database - document_ids_from_vectors = db.query(models.VectorMetadata.document_id).filter( - models.VectorMetadata.faiss_index.in_(faiss_ids) - ).all() - - document_ids = [doc_id for (doc_id,) in document_ids_from_vectors] - - # Retrieve the full documents from the Document table - context_docs = db.query(models.Document).filter( - models.Document.id.in_(document_ids) - ).all() - - context_docs_text = [doc.text for doc in context_docs] - - return context_docs_text - -# You could add other retriever implementations here, like: -# class RemoteServiceRetriever(Retriever): -# def retrieve_context(self, query: str, db: Session) -> List[str]: -# # Logic to call a remote API and return context -# ... diff --git a/ai-hub/app/core/retrievers/__init__.py b/ai-hub/app/core/retrievers/__init__.py new file mode 100644 index 0000000..3fbb1fd --- /dev/null +++ b/ai-hub/app/core/retrievers/__init__.py @@ -0,0 +1 @@ +# This file can be left empty. diff --git a/ai-hub/app/core/retrievers/base_retriever.py b/ai-hub/app/core/retrievers/base_retriever.py new file mode 100644 index 0000000..b902132 --- /dev/null +++ b/ai-hub/app/core/retrievers/base_retriever.py @@ -0,0 +1,24 @@ +import abc +from typing import List +from sqlalchemy.orm import Session + +class Retriever(abc.ABC): + """ + Abstract base class for a Retriever. + + A retriever is a pluggable component that is responsible for fetching + relevant context for a given query from a specific data source. + """ + @abc.abstractmethod + def retrieve_context(self, query: str, db: Session) -> List[str]: + """ + Fetches context for a given query. + + Args: + query (str): The user's query string. + db (Session): The database session. + + Returns: + List[str]: A list of text strings representing the retrieved context. + """ + raise NotImplementedError \ No newline at end of file diff --git a/ai-hub/app/core/retrievers/faiss_db_retriever.py b/ai-hub/app/core/retrievers/faiss_db_retriever.py new file mode 100644 index 0000000..3c3c202 --- /dev/null +++ b/ai-hub/app/core/retrievers/faiss_db_retriever.py @@ -0,0 +1,38 @@ +from typing import List +from sqlalchemy.orm import Session +from app.core.vector_store.faiss_store import FaissVectorStore +from app.db import models +from app.core.retrievers.base_retriever import Retriever + +class FaissDBRetriever(Retriever): + """ + A concrete retriever that uses a FAISS index and a local database + to find and return relevant document text. + """ + def __init__(self, vector_store: FaissVectorStore): + self.vector_store = vector_store + + def retrieve_context(self, query: str, db: Session) -> List[str]: + """ + Retrieves document text by first searching the FAISS index + and then fetching the corresponding documents from the database. + """ + faiss_ids = self.vector_store.search_similar_documents(query, k=3) + context_docs_text = [] + + if faiss_ids: + # Use FAISS IDs to find the corresponding document_id from the database + document_ids_from_vectors = db.query(models.VectorMetadata.document_id).filter( + models.VectorMetadata.faiss_index.in_(faiss_ids) + ).all() + + document_ids = [doc_id for (doc_id,) in document_ids_from_vectors] + + # Retrieve the full documents from the Document table + context_docs = db.query(models.Document).filter( + models.Document.id.in_(document_ids) + ).all() + + context_docs_text = [doc.text for doc in context_docs] + + return context_docs_text \ No newline at end of file diff --git a/ai-hub/app/core/services.py b/ai-hub/app/core/services.py deleted file mode 100644 index a5694dc..0000000 --- a/ai-hub/app/core/services.py +++ /dev/null @@ -1,141 +0,0 @@ -# import asyncio -# from typing import List, Dict, Any, Tuple -# from sqlalchemy.orm import Session, joinedload -# from sqlalchemy.exc import SQLAlchemyError -# import dspy - -# from app.core.vector_store.faiss_store import FaissVectorStore -# from app.core.vector_store.embedder.mock import MockEmbedder -# from app.db import models -# from app.core.retrievers import Retriever, FaissDBRetriever -# from app.core.llm_providers import get_llm_provider -# from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline - -# class RAGService: -# """ -# Service class for managing documents and conversational RAG sessions. -# This class is now more robust and can handle both real and mock embedders -# by inspecting its dependencies. -# """ -# def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): -# self.vector_store = vector_store -# self.retrievers = retrievers - -# # Assume one of the retrievers is the FAISS retriever, and you can access it. -# # A better approach might be to have a dictionary of named retrievers. -# self.faiss_retriever = next((r for r in retrievers if isinstance(r, FaissDBRetriever)), None) - -# # Store the embedder from the vector store for dynamic naming -# self.embedder = self.vector_store.embedder - - -# # --- Session Management --- - -# def create_session(self, db: Session, user_id: str, model: str) -> models.Session: -# """Creates a new chat session in the database.""" -# try: -# new_session = models.Session(user_id=user_id, model_name=model, title=f"New Chat Session") -# db.add(new_session) -# db.commit() -# db.refresh(new_session) -# return new_session -# except SQLAlchemyError as e: -# db.rollback() -# raise - -# async def chat_with_rag( -# self, -# db: Session, -# session_id: int, -# prompt: str, -# model: str, -# load_faiss_retriever: bool = False -# ) -> Tuple[str, str]: -# """ -# Handles a message within a session, including saving history and getting a response. -# Allows switching the LLM model and conditionally using the FAISS retriever. -# """ -# session = db.query(models.Session).options( -# joinedload(models.Session.messages) -# ).filter(models.Session.id == session_id).first() - -# if not session: -# raise ValueError(f"Session with ID {session_id} not found.") - -# user_message = models.Message(session_id=session_id, sender="user", content=prompt) -# db.add(user_message) -# db.commit() - -# llm_provider = get_llm_provider(model) -# dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=model) -# dspy.configure(lm=dspy_llm) - -# current_retrievers = [] -# if load_faiss_retriever: -# if self.faiss_retriever: -# current_retrievers.append(self.faiss_retriever) -# else: -# print("Warning: FaissDBRetriever requested but not available. Proceeding without it.") - -# rag_pipeline = DspyRagPipeline(retrievers=current_retrievers) - -# answer_text = await rag_pipeline.forward( -# question=prompt, -# history=session.messages, -# db=db -# ) - -# assistant_message = models.Message(session_id=session_id, sender="assistant", content=answer_text) -# db.add(assistant_message) -# db.commit() - -# return answer_text, model - -# def get_message_history(self, db: Session, session_id: int) -> List[models.Message]: -# """ -# Retrieves all messages for a given session, or None if the session doesn't exist. -# """ -# session = db.query(models.Session).options( -# joinedload(models.Session.messages) -# ).filter(models.Session.id == session_id).first() - -# return sorted(session.messages, key=lambda msg: msg.created_at) if session else None - -# # --- Document Management (Updated) --- -# def add_document(self, db: Session, doc_data: Dict[str, Any]) -> int: -# try: -# document_db = models.Document(**doc_data) -# db.add(document_db) -# db.commit() -# db.refresh(document_db) - -# # Use the embedder provided to the vector store to get the correct model name -# embedding_model_name = "mock_embedder" if isinstance(self.embedder, MockEmbedder) else "GenAIEmbedder" - -# faiss_index = self.vector_store.add_document(document_db.text) -# vector_metadata = models.VectorMetadata( -# document_id=document_db.id, -# faiss_index=faiss_index, -# embedding_model=embedding_model_name -# ) -# db.add(vector_metadata) -# db.commit() -# return document_db.id -# except SQLAlchemyError as e: -# db.rollback() -# raise - -# def get_all_documents(self, db: Session) -> List[models.Document]: -# return db.query(models.Document).order_by(models.Document.created_at.desc()).all() - -# def delete_document(self, db: Session, document_id: int) -> int: -# try: -# doc_to_delete = db.query(models.Document).filter(models.Document.id == document_id).first() -# if not doc_to_delete: -# return None -# db.delete(doc_to_delete) -# db.commit() -# return document_id -# except SQLAlchemyError as e: -# db.rollback() -# raise diff --git a/ai-hub/app/core/services/rag.py b/ai-hub/app/core/services/rag.py index 344e779..517278d 100644 --- a/ai-hub/app/core/services/rag.py +++ b/ai-hub/app/core/services/rag.py @@ -6,7 +6,8 @@ from app.core.vector_store.faiss_store import FaissVectorStore from app.db import models -from app.core.retrievers import Retriever, FaissDBRetriever +from app.core.retrievers.faiss_db_retriever import FaissDBRetriever +from app.core.retrievers.base_retriever import Retriever from app.core.llm_providers import get_llm_provider from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline diff --git a/ai-hub/tests/api/test_dependencies.py b/ai-hub/tests/api/test_dependencies.py index 470780d..65490e8 100644 --- a/ai-hub/tests/api/test_dependencies.py +++ b/ai-hub/tests/api/test_dependencies.py @@ -10,7 +10,7 @@ from app.core.services.document import DocumentService from app.core.services.rag import RAGService from app.core.vector_store.faiss_store import FaissVectorStore -from app.core.retrievers import Retriever +from app.core.retrievers.base_retriever import Retriever @pytest.fixture def mock_session(): diff --git a/ai-hub/tests/core/pipelines/test_dspy_rag.py b/ai-hub/tests/core/pipelines/test_dspy_rag.py index d1b8c62..4e44c4b 100644 --- a/ai-hub/tests/core/pipelines/test_dspy_rag.py +++ b/ai-hub/tests/core/pipelines/test_dspy_rag.py @@ -7,7 +7,7 @@ # Import the pipeline and its new signature from app.core.pipelines.dspy_rag import DspyRagPipeline, AnswerWithHistory from app.db import models # Import your SQLAlchemy models for mocking history -from app.core.retrievers import Retriever +from app.core.retrievers.base_retriever import Retriever @pytest.fixture def mock_lm_configured(): diff --git a/ai-hub/tests/core/retrievers/test_faiss_db_retriever.py b/ai-hub/tests/core/retrievers/test_faiss_db_retriever.py new file mode 100644 index 0000000..695d5ae --- /dev/null +++ b/ai-hub/tests/core/retrievers/test_faiss_db_retriever.py @@ -0,0 +1,127 @@ +import pytest +from sqlalchemy import create_engine, Column, Integer, String, ForeignKey +from sqlalchemy.orm import sessionmaker, declarative_base, relationship +from typing import List + +# Mock the required models and FaissVectorStore for testing purposes +Base = declarative_base() + +class Document(Base): + __tablename__ = "documents" + id = Column(Integer, primary_key=True, index=True) + title = Column(String) # Add the missing columns + text = Column(String) + source_url = Column(String) + author = Column(String) + status = Column(String) + created_at = Column(String) + user_id = Column(Integer) + vectors = relationship("VectorMetadata", back_populates="document") + +class VectorMetadata(Base): + __tablename__ = "vector_metadata" + id = Column(Integer, primary_key=True, index=True) + faiss_index = Column(Integer, unique=True) + document_id = Column(Integer, ForeignKey("documents.id")) + document = relationship("Document", back_populates="vectors") + +# A mock version of the FaissVectorStore for testing +class FaissVectorStore: + def __init__(self, encoder): + self._index = {} + self._next_id = 0 + + def add_document_to_index(self, text: str) -> int: + self._index[self._next_id] = text + current_id = self._next_id + self._next_id += 1 + return current_id + + def search_similar_documents(self, query: str, k: int = 1) -> List[int]: + # This is a mock; in a real scenario, this would perform a vector search. + # Here, we'll assume it returns the IDs we expect for the test. + # This method is often patched in the tests. + return [0] + +# --- E2E test setup and fixtures --- +SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db" +engine = create_engine( + SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False} +) +TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +@pytest.fixture(scope="function") +def db_session(): + """Provides a clean database session for each test.""" + Base.metadata.create_all(bind=engine) + db = TestingSessionLocal() + try: + yield db + finally: + db.close() + Base.metadata.drop_all(bind=engine) + +@pytest.fixture(scope="function") +def faiss_store(): + """Provides a fresh FaissVectorStore instance for each test.""" + class MockEncoder: + def encode(self, text): + return [1.0] * 768 + + return FaissVectorStore(MockEncoder()) + +# --- E2E test cases --- +# Assuming FaissDBRetriever and its dependencies are correctly imported +# You need to make sure the import path is correct for your project structure +from app.core.retrievers.faiss_db_retriever import FaissDBRetriever + +def test_retrieve_context_successful(db_session, faiss_store): + """ + Tests that the retriever successfully finds and returns the correct + document text based on a FAISS search. + """ + # 1. Setup - Create documents and vectors + doc1 = Document( + title="Sample Title", # Provide values for the new columns + text="The quick brown fox jumps over the lazy dog." + ) + db_session.add(doc1) + db_session.commit() + db_session.refresh(doc1) + + # Mock the FAISS store to return a predictable ID + faiss_id1 = 123 + faiss_store.add_document_to_index = lambda text: faiss_id1 + + vec_meta1 = VectorMetadata(document_id=doc1.id, faiss_index=faiss_id1) + db_session.add(vec_meta1) + db_session.commit() + + # 2. Execution - Create and run the retriever + retriever = FaissDBRetriever(vector_store=faiss_store) + + # We'll mock the search to return the ID of our specific document + faiss_store.search_similar_documents = lambda query, k: [faiss_id1] + retrieved_context = retriever.retrieve_context(query="query for fox", db=db_session) + + # 3. Assertion - Verify the result + assert len(retrieved_context) == 1 + assert retrieved_context[0] == "The quick brown fox jumps over the lazy dog." + +def test_retrieve_context_no_match(db_session, faiss_store): + """ + Tests that the retriever returns an empty list when no matching + documents are found in the FAISS index. + """ + # 1. Setup - No documents or vectors are added to the database. + + # 2. Execution - Create and run the retriever + retriever = FaissDBRetriever(vector_store=faiss_store) + + # Mock the search to return an empty list + faiss_store.search_similar_documents = lambda query, k: [] + + retrieved_context = retriever.retrieve_context(query="non-existent query", db=db_session) + + # 3. Assertion - Verify the result is an empty list + assert retrieved_context == [] \ No newline at end of file diff --git a/ai-hub/tests/core/services/test_rag.py b/ai-hub/tests/core/services/test_rag.py index 2fd4ab3..c431c4e 100644 --- a/ai-hub/tests/core/services/test_rag.py +++ b/ai-hub/tests/core/services/test_rag.py @@ -11,7 +11,8 @@ from app.db import models from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.mock import MockEmbedder -from app.core.retrievers import FaissDBRetriever, Retriever +from app.core.retrievers.faiss_db_retriever import FaissDBRetriever, Retriever +from app.core.retrievers.base_retriever import Retriever from app.core.pipelines.dspy_rag import DspyRagPipeline from app.core.llm_providers import LLMProvider diff --git a/ai-hub/app/api/dependencies.py b/ai-hub/app/api/dependencies.py index 18afe8f..e1f2be7 100644 --- a/ai-hub/app/api/dependencies.py +++ b/ai-hub/app/api/dependencies.py @@ -3,7 +3,7 @@ from typing import List from sqlalchemy.orm import Session from app.db.session import SessionLocal -from app.core.retrievers import Retriever +from app.core.retrievers.base_retriever import Retriever from app.core.services.document import DocumentService from app.core.services.rag import RAGService from app.core.vector_store.faiss_store import FaissVectorStore diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index 94413c8..a747ae6 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -6,7 +6,8 @@ from app.config import settings from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.factory import get_embedder_from_config -from app.core.retrievers import FaissDBRetriever, Retriever +from app.core.retrievers.faiss_db_retriever import FaissDBRetriever +from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables from app.api.routes import create_api_router from app.utils import print_config diff --git a/ai-hub/app/core/pipelines/dspy_rag.py b/ai-hub/app/core/pipelines/dspy_rag.py index d5e44de..88d1915 100644 --- a/ai-hub/app/core/pipelines/dspy_rag.py +++ b/ai-hub/app/core/pipelines/dspy_rag.py @@ -5,7 +5,7 @@ from sqlalchemy.orm import Session from app.db import models -from app.core.retrievers import Retriever +from app.core.retrievers.base_retriever import Retriever from app.core.llm_providers import LLMProvider diff --git a/ai-hub/app/core/retrievers.py b/ai-hub/app/core/retrievers.py deleted file mode 100644 index d2c158b..0000000 --- a/ai-hub/app/core/retrievers.py +++ /dev/null @@ -1,65 +0,0 @@ -import abc -from typing import List, Dict -from sqlalchemy.orm import Session -from app.core.vector_store.faiss_store import FaissVectorStore -from app.db import models - -class Retriever(abc.ABC): - """ - Abstract base class for a Retriever. - - A retriever is a pluggable component that is responsible for fetching - relevant context for a given query from a specific data source. - """ - @abc.abstractmethod - def retrieve_context(self, query: str, db: Session) -> List[str]: - """ - Fetches context for a given query. - - Args: - query (str): The user's query string. - db (Session): The database session. - - Returns: - List[str]: A list of text strings representing the retrieved context. - """ - raise NotImplementedError - -class FaissDBRetriever(Retriever): - """ - A concrete retriever that uses a FAISS index and a local database - to find and return relevant document text. - """ - def __init__(self, vector_store: FaissVectorStore): - self.vector_store = vector_store - - def retrieve_context(self, query: str, db: Session) -> List[str]: - """ - Retrieves document text by first searching the FAISS index - and then fetching the corresponding documents from the database. - """ - faiss_ids = self.vector_store.search_similar_documents(query, k=3) - context_docs_text = [] - - if faiss_ids: - # Use FAISS IDs to find the corresponding document_id from the database - document_ids_from_vectors = db.query(models.VectorMetadata.document_id).filter( - models.VectorMetadata.faiss_index.in_(faiss_ids) - ).all() - - document_ids = [doc_id for (doc_id,) in document_ids_from_vectors] - - # Retrieve the full documents from the Document table - context_docs = db.query(models.Document).filter( - models.Document.id.in_(document_ids) - ).all() - - context_docs_text = [doc.text for doc in context_docs] - - return context_docs_text - -# You could add other retriever implementations here, like: -# class RemoteServiceRetriever(Retriever): -# def retrieve_context(self, query: str, db: Session) -> List[str]: -# # Logic to call a remote API and return context -# ... diff --git a/ai-hub/app/core/retrievers/__init__.py b/ai-hub/app/core/retrievers/__init__.py new file mode 100644 index 0000000..3fbb1fd --- /dev/null +++ b/ai-hub/app/core/retrievers/__init__.py @@ -0,0 +1 @@ +# This file can be left empty. diff --git a/ai-hub/app/core/retrievers/base_retriever.py b/ai-hub/app/core/retrievers/base_retriever.py new file mode 100644 index 0000000..b902132 --- /dev/null +++ b/ai-hub/app/core/retrievers/base_retriever.py @@ -0,0 +1,24 @@ +import abc +from typing import List +from sqlalchemy.orm import Session + +class Retriever(abc.ABC): + """ + Abstract base class for a Retriever. + + A retriever is a pluggable component that is responsible for fetching + relevant context for a given query from a specific data source. + """ + @abc.abstractmethod + def retrieve_context(self, query: str, db: Session) -> List[str]: + """ + Fetches context for a given query. + + Args: + query (str): The user's query string. + db (Session): The database session. + + Returns: + List[str]: A list of text strings representing the retrieved context. + """ + raise NotImplementedError \ No newline at end of file diff --git a/ai-hub/app/core/retrievers/faiss_db_retriever.py b/ai-hub/app/core/retrievers/faiss_db_retriever.py new file mode 100644 index 0000000..3c3c202 --- /dev/null +++ b/ai-hub/app/core/retrievers/faiss_db_retriever.py @@ -0,0 +1,38 @@ +from typing import List +from sqlalchemy.orm import Session +from app.core.vector_store.faiss_store import FaissVectorStore +from app.db import models +from app.core.retrievers.base_retriever import Retriever + +class FaissDBRetriever(Retriever): + """ + A concrete retriever that uses a FAISS index and a local database + to find and return relevant document text. + """ + def __init__(self, vector_store: FaissVectorStore): + self.vector_store = vector_store + + def retrieve_context(self, query: str, db: Session) -> List[str]: + """ + Retrieves document text by first searching the FAISS index + and then fetching the corresponding documents from the database. + """ + faiss_ids = self.vector_store.search_similar_documents(query, k=3) + context_docs_text = [] + + if faiss_ids: + # Use FAISS IDs to find the corresponding document_id from the database + document_ids_from_vectors = db.query(models.VectorMetadata.document_id).filter( + models.VectorMetadata.faiss_index.in_(faiss_ids) + ).all() + + document_ids = [doc_id for (doc_id,) in document_ids_from_vectors] + + # Retrieve the full documents from the Document table + context_docs = db.query(models.Document).filter( + models.Document.id.in_(document_ids) + ).all() + + context_docs_text = [doc.text for doc in context_docs] + + return context_docs_text \ No newline at end of file diff --git a/ai-hub/app/core/services.py b/ai-hub/app/core/services.py deleted file mode 100644 index a5694dc..0000000 --- a/ai-hub/app/core/services.py +++ /dev/null @@ -1,141 +0,0 @@ -# import asyncio -# from typing import List, Dict, Any, Tuple -# from sqlalchemy.orm import Session, joinedload -# from sqlalchemy.exc import SQLAlchemyError -# import dspy - -# from app.core.vector_store.faiss_store import FaissVectorStore -# from app.core.vector_store.embedder.mock import MockEmbedder -# from app.db import models -# from app.core.retrievers import Retriever, FaissDBRetriever -# from app.core.llm_providers import get_llm_provider -# from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline - -# class RAGService: -# """ -# Service class for managing documents and conversational RAG sessions. -# This class is now more robust and can handle both real and mock embedders -# by inspecting its dependencies. -# """ -# def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): -# self.vector_store = vector_store -# self.retrievers = retrievers - -# # Assume one of the retrievers is the FAISS retriever, and you can access it. -# # A better approach might be to have a dictionary of named retrievers. -# self.faiss_retriever = next((r for r in retrievers if isinstance(r, FaissDBRetriever)), None) - -# # Store the embedder from the vector store for dynamic naming -# self.embedder = self.vector_store.embedder - - -# # --- Session Management --- - -# def create_session(self, db: Session, user_id: str, model: str) -> models.Session: -# """Creates a new chat session in the database.""" -# try: -# new_session = models.Session(user_id=user_id, model_name=model, title=f"New Chat Session") -# db.add(new_session) -# db.commit() -# db.refresh(new_session) -# return new_session -# except SQLAlchemyError as e: -# db.rollback() -# raise - -# async def chat_with_rag( -# self, -# db: Session, -# session_id: int, -# prompt: str, -# model: str, -# load_faiss_retriever: bool = False -# ) -> Tuple[str, str]: -# """ -# Handles a message within a session, including saving history and getting a response. -# Allows switching the LLM model and conditionally using the FAISS retriever. -# """ -# session = db.query(models.Session).options( -# joinedload(models.Session.messages) -# ).filter(models.Session.id == session_id).first() - -# if not session: -# raise ValueError(f"Session with ID {session_id} not found.") - -# user_message = models.Message(session_id=session_id, sender="user", content=prompt) -# db.add(user_message) -# db.commit() - -# llm_provider = get_llm_provider(model) -# dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=model) -# dspy.configure(lm=dspy_llm) - -# current_retrievers = [] -# if load_faiss_retriever: -# if self.faiss_retriever: -# current_retrievers.append(self.faiss_retriever) -# else: -# print("Warning: FaissDBRetriever requested but not available. Proceeding without it.") - -# rag_pipeline = DspyRagPipeline(retrievers=current_retrievers) - -# answer_text = await rag_pipeline.forward( -# question=prompt, -# history=session.messages, -# db=db -# ) - -# assistant_message = models.Message(session_id=session_id, sender="assistant", content=answer_text) -# db.add(assistant_message) -# db.commit() - -# return answer_text, model - -# def get_message_history(self, db: Session, session_id: int) -> List[models.Message]: -# """ -# Retrieves all messages for a given session, or None if the session doesn't exist. -# """ -# session = db.query(models.Session).options( -# joinedload(models.Session.messages) -# ).filter(models.Session.id == session_id).first() - -# return sorted(session.messages, key=lambda msg: msg.created_at) if session else None - -# # --- Document Management (Updated) --- -# def add_document(self, db: Session, doc_data: Dict[str, Any]) -> int: -# try: -# document_db = models.Document(**doc_data) -# db.add(document_db) -# db.commit() -# db.refresh(document_db) - -# # Use the embedder provided to the vector store to get the correct model name -# embedding_model_name = "mock_embedder" if isinstance(self.embedder, MockEmbedder) else "GenAIEmbedder" - -# faiss_index = self.vector_store.add_document(document_db.text) -# vector_metadata = models.VectorMetadata( -# document_id=document_db.id, -# faiss_index=faiss_index, -# embedding_model=embedding_model_name -# ) -# db.add(vector_metadata) -# db.commit() -# return document_db.id -# except SQLAlchemyError as e: -# db.rollback() -# raise - -# def get_all_documents(self, db: Session) -> List[models.Document]: -# return db.query(models.Document).order_by(models.Document.created_at.desc()).all() - -# def delete_document(self, db: Session, document_id: int) -> int: -# try: -# doc_to_delete = db.query(models.Document).filter(models.Document.id == document_id).first() -# if not doc_to_delete: -# return None -# db.delete(doc_to_delete) -# db.commit() -# return document_id -# except SQLAlchemyError as e: -# db.rollback() -# raise diff --git a/ai-hub/app/core/services/rag.py b/ai-hub/app/core/services/rag.py index 344e779..517278d 100644 --- a/ai-hub/app/core/services/rag.py +++ b/ai-hub/app/core/services/rag.py @@ -6,7 +6,8 @@ from app.core.vector_store.faiss_store import FaissVectorStore from app.db import models -from app.core.retrievers import Retriever, FaissDBRetriever +from app.core.retrievers.faiss_db_retriever import FaissDBRetriever +from app.core.retrievers.base_retriever import Retriever from app.core.llm_providers import get_llm_provider from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline diff --git a/ai-hub/tests/api/test_dependencies.py b/ai-hub/tests/api/test_dependencies.py index 470780d..65490e8 100644 --- a/ai-hub/tests/api/test_dependencies.py +++ b/ai-hub/tests/api/test_dependencies.py @@ -10,7 +10,7 @@ from app.core.services.document import DocumentService from app.core.services.rag import RAGService from app.core.vector_store.faiss_store import FaissVectorStore -from app.core.retrievers import Retriever +from app.core.retrievers.base_retriever import Retriever @pytest.fixture def mock_session(): diff --git a/ai-hub/tests/core/pipelines/test_dspy_rag.py b/ai-hub/tests/core/pipelines/test_dspy_rag.py index d1b8c62..4e44c4b 100644 --- a/ai-hub/tests/core/pipelines/test_dspy_rag.py +++ b/ai-hub/tests/core/pipelines/test_dspy_rag.py @@ -7,7 +7,7 @@ # Import the pipeline and its new signature from app.core.pipelines.dspy_rag import DspyRagPipeline, AnswerWithHistory from app.db import models # Import your SQLAlchemy models for mocking history -from app.core.retrievers import Retriever +from app.core.retrievers.base_retriever import Retriever @pytest.fixture def mock_lm_configured(): diff --git a/ai-hub/tests/core/retrievers/test_faiss_db_retriever.py b/ai-hub/tests/core/retrievers/test_faiss_db_retriever.py new file mode 100644 index 0000000..695d5ae --- /dev/null +++ b/ai-hub/tests/core/retrievers/test_faiss_db_retriever.py @@ -0,0 +1,127 @@ +import pytest +from sqlalchemy import create_engine, Column, Integer, String, ForeignKey +from sqlalchemy.orm import sessionmaker, declarative_base, relationship +from typing import List + +# Mock the required models and FaissVectorStore for testing purposes +Base = declarative_base() + +class Document(Base): + __tablename__ = "documents" + id = Column(Integer, primary_key=True, index=True) + title = Column(String) # Add the missing columns + text = Column(String) + source_url = Column(String) + author = Column(String) + status = Column(String) + created_at = Column(String) + user_id = Column(Integer) + vectors = relationship("VectorMetadata", back_populates="document") + +class VectorMetadata(Base): + __tablename__ = "vector_metadata" + id = Column(Integer, primary_key=True, index=True) + faiss_index = Column(Integer, unique=True) + document_id = Column(Integer, ForeignKey("documents.id")) + document = relationship("Document", back_populates="vectors") + +# A mock version of the FaissVectorStore for testing +class FaissVectorStore: + def __init__(self, encoder): + self._index = {} + self._next_id = 0 + + def add_document_to_index(self, text: str) -> int: + self._index[self._next_id] = text + current_id = self._next_id + self._next_id += 1 + return current_id + + def search_similar_documents(self, query: str, k: int = 1) -> List[int]: + # This is a mock; in a real scenario, this would perform a vector search. + # Here, we'll assume it returns the IDs we expect for the test. + # This method is often patched in the tests. + return [0] + +# --- E2E test setup and fixtures --- +SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db" +engine = create_engine( + SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False} +) +TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +@pytest.fixture(scope="function") +def db_session(): + """Provides a clean database session for each test.""" + Base.metadata.create_all(bind=engine) + db = TestingSessionLocal() + try: + yield db + finally: + db.close() + Base.metadata.drop_all(bind=engine) + +@pytest.fixture(scope="function") +def faiss_store(): + """Provides a fresh FaissVectorStore instance for each test.""" + class MockEncoder: + def encode(self, text): + return [1.0] * 768 + + return FaissVectorStore(MockEncoder()) + +# --- E2E test cases --- +# Assuming FaissDBRetriever and its dependencies are correctly imported +# You need to make sure the import path is correct for your project structure +from app.core.retrievers.faiss_db_retriever import FaissDBRetriever + +def test_retrieve_context_successful(db_session, faiss_store): + """ + Tests that the retriever successfully finds and returns the correct + document text based on a FAISS search. + """ + # 1. Setup - Create documents and vectors + doc1 = Document( + title="Sample Title", # Provide values for the new columns + text="The quick brown fox jumps over the lazy dog." + ) + db_session.add(doc1) + db_session.commit() + db_session.refresh(doc1) + + # Mock the FAISS store to return a predictable ID + faiss_id1 = 123 + faiss_store.add_document_to_index = lambda text: faiss_id1 + + vec_meta1 = VectorMetadata(document_id=doc1.id, faiss_index=faiss_id1) + db_session.add(vec_meta1) + db_session.commit() + + # 2. Execution - Create and run the retriever + retriever = FaissDBRetriever(vector_store=faiss_store) + + # We'll mock the search to return the ID of our specific document + faiss_store.search_similar_documents = lambda query, k: [faiss_id1] + retrieved_context = retriever.retrieve_context(query="query for fox", db=db_session) + + # 3. Assertion - Verify the result + assert len(retrieved_context) == 1 + assert retrieved_context[0] == "The quick brown fox jumps over the lazy dog." + +def test_retrieve_context_no_match(db_session, faiss_store): + """ + Tests that the retriever returns an empty list when no matching + documents are found in the FAISS index. + """ + # 1. Setup - No documents or vectors are added to the database. + + # 2. Execution - Create and run the retriever + retriever = FaissDBRetriever(vector_store=faiss_store) + + # Mock the search to return an empty list + faiss_store.search_similar_documents = lambda query, k: [] + + retrieved_context = retriever.retrieve_context(query="non-existent query", db=db_session) + + # 3. Assertion - Verify the result is an empty list + assert retrieved_context == [] \ No newline at end of file diff --git a/ai-hub/tests/core/services/test_rag.py b/ai-hub/tests/core/services/test_rag.py index 2fd4ab3..c431c4e 100644 --- a/ai-hub/tests/core/services/test_rag.py +++ b/ai-hub/tests/core/services/test_rag.py @@ -11,7 +11,8 @@ from app.db import models from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.mock import MockEmbedder -from app.core.retrievers import FaissDBRetriever, Retriever +from app.core.retrievers.faiss_db_retriever import FaissDBRetriever, Retriever +from app.core.retrievers.base_retriever import Retriever from app.core.pipelines.dspy_rag import DspyRagPipeline from app.core.llm_providers import LLMProvider diff --git a/ai-hub/tests/test_app.py b/ai-hub/tests/test_app.py index 1d7a96d..73eb307 100644 --- a/ai-hub/tests/test_app.py +++ b/ai-hub/tests/test_app.py @@ -11,7 +11,7 @@ from app.app import create_app from app.api.dependencies import get_db, ServiceContainer from app.db import models -from app.core.retrievers import Retriever +from app.core.retrievers.base_retriever import Retriever # Define a constant for the dimension to ensure consistency TEST_DIMENSION = 768