diff --git a/.gitignore b/.gitignore index 95f61e0..7782b54 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,6 @@ .env **/.env **/*.egg-info -**/faiss_index.bin -**/ai_hub.db -.pytest_cache/ \ No newline at end of file +.pytest_cache/ +**.bin +**.db \ No newline at end of file diff --git a/.gitignore b/.gitignore index 95f61e0..7782b54 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,6 @@ .env **/.env **/*.egg-info -**/faiss_index.bin -**/ai_hub.db -.pytest_cache/ \ No newline at end of file +.pytest_cache/ +**.bin +**.db \ No newline at end of file diff --git a/ai-hub/app/config.py b/ai-hub/app/config.py index 603b12f..2d79f0e 100644 --- a/ai-hub/app/config.py +++ b/ai-hub/app/config.py @@ -24,8 +24,9 @@ log_level: str = "INFO" class DatabaseSettings(BaseModel): - mode: str = "sqlite" - url: str = "postgresql://user:password@localhost/ai_hub_db" + mode: str = "sqlite" # "sqlite" or "postgresql" + url: Optional[str] = None # Used if mode != "sqlite" + local_path: str = "data/ai_hub.db" # Used if mode == "sqlite" class LLMProviderSettings(BaseModel): deepseek_model_name: str = "deepseek-chat" @@ -83,20 +84,31 @@ get_from_yaml(["application", "log_level"]) or \ config_from_pydantic.application.log_level + # --- Database Settings --- self.DB_MODE: str = os.getenv("DB_MODE") or \ get_from_yaml(["database", "mode"]) or \ config_from_pydantic.database.mode - + + # Get local path for SQLite, from env/yaml/pydantic + local_db_path = os.getenv("LOCAL_DB_PATH") or \ + get_from_yaml(["database", "local_path"]) or \ + config_from_pydantic.database.local_path + + # Get external DB URL, from env/yaml/pydantic + external_db_url = os.getenv("DATABASE_URL") or \ + get_from_yaml(["database", "url"]) or \ + config_from_pydantic.database.url + if self.DB_MODE == "sqlite": - self.DATABASE_URL: str = "sqlite:///./data/ai_hub.db" + # Ensure path does not have duplicate ./ prefix + normalized_path = local_db_path.lstrip("./") + self.DATABASE_URL: str = f"sqlite:///./{normalized_path}" if normalized_path else "sqlite:///./data/ai_hub.db" else: - self.DATABASE_URL: str = os.getenv("DATABASE_URL") or \ - get_from_yaml(["database", "url"]) or \ - config_from_pydantic.database.url + self.DATABASE_URL: str = external_db_url or "sqlite:///./data/ai_hub.db" # fallback if no URL provided - self.DEEPSEEK_API_KEY: str = os.getenv("DEEPSEEK_API_KEY") - self.GEMINI_API_KEY: str = os.getenv("GEMINI_API_KEY") - # Removed the ValueError here to allow tests to run + # --- API Keys --- + self.DEEPSEEK_API_KEY: Optional[str] = os.getenv("DEEPSEEK_API_KEY") + self.GEMINI_API_KEY: Optional[str] = os.getenv("GEMINI_API_KEY") self.DEEPSEEK_MODEL_NAME: str = os.getenv("DEEPSEEK_MODEL_NAME") or \ get_from_yaml(["llm_providers", "deepseek_model_name"]) or \ @@ -116,7 +128,6 @@ self.EMBEDDING_DIMENSION: int = int(dimension_str) # New embedding provider settings - # Convert the environment variable value to lowercase to match the enum embedding_provider_env = os.getenv("EMBEDDING_PROVIDER") if embedding_provider_env: embedding_provider_env = embedding_provider_env.lower() @@ -135,5 +146,6 @@ self.EMBEDDING_API_KEY: Optional[str] = api_key_env or api_key_yaml or api_key_pydantic + # Instantiate the single settings object for the application settings = Settings() diff --git a/.gitignore b/.gitignore index 95f61e0..7782b54 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,6 @@ .env **/.env **/*.egg-info -**/faiss_index.bin -**/ai_hub.db -.pytest_cache/ \ No newline at end of file +.pytest_cache/ +**.bin +**.db \ No newline at end of file diff --git a/ai-hub/app/config.py b/ai-hub/app/config.py index 603b12f..2d79f0e 100644 --- a/ai-hub/app/config.py +++ b/ai-hub/app/config.py @@ -24,8 +24,9 @@ log_level: str = "INFO" class DatabaseSettings(BaseModel): - mode: str = "sqlite" - url: str = "postgresql://user:password@localhost/ai_hub_db" + mode: str = "sqlite" # "sqlite" or "postgresql" + url: Optional[str] = None # Used if mode != "sqlite" + local_path: str = "data/ai_hub.db" # Used if mode == "sqlite" class LLMProviderSettings(BaseModel): deepseek_model_name: str = "deepseek-chat" @@ -83,20 +84,31 @@ get_from_yaml(["application", "log_level"]) or \ config_from_pydantic.application.log_level + # --- Database Settings --- self.DB_MODE: str = os.getenv("DB_MODE") or \ get_from_yaml(["database", "mode"]) or \ config_from_pydantic.database.mode - + + # Get local path for SQLite, from env/yaml/pydantic + local_db_path = os.getenv("LOCAL_DB_PATH") or \ + get_from_yaml(["database", "local_path"]) or \ + config_from_pydantic.database.local_path + + # Get external DB URL, from env/yaml/pydantic + external_db_url = os.getenv("DATABASE_URL") or \ + get_from_yaml(["database", "url"]) or \ + config_from_pydantic.database.url + if self.DB_MODE == "sqlite": - self.DATABASE_URL: str = "sqlite:///./data/ai_hub.db" + # Ensure path does not have duplicate ./ prefix + normalized_path = local_db_path.lstrip("./") + self.DATABASE_URL: str = f"sqlite:///./{normalized_path}" if normalized_path else "sqlite:///./data/ai_hub.db" else: - self.DATABASE_URL: str = os.getenv("DATABASE_URL") or \ - get_from_yaml(["database", "url"]) or \ - config_from_pydantic.database.url + self.DATABASE_URL: str = external_db_url or "sqlite:///./data/ai_hub.db" # fallback if no URL provided - self.DEEPSEEK_API_KEY: str = os.getenv("DEEPSEEK_API_KEY") - self.GEMINI_API_KEY: str = os.getenv("GEMINI_API_KEY") - # Removed the ValueError here to allow tests to run + # --- API Keys --- + self.DEEPSEEK_API_KEY: Optional[str] = os.getenv("DEEPSEEK_API_KEY") + self.GEMINI_API_KEY: Optional[str] = os.getenv("GEMINI_API_KEY") self.DEEPSEEK_MODEL_NAME: str = os.getenv("DEEPSEEK_MODEL_NAME") or \ get_from_yaml(["llm_providers", "deepseek_model_name"]) or \ @@ -116,7 +128,6 @@ self.EMBEDDING_DIMENSION: int = int(dimension_str) # New embedding provider settings - # Convert the environment variable value to lowercase to match the enum embedding_provider_env = os.getenv("EMBEDDING_PROVIDER") if embedding_provider_env: embedding_provider_env = embedding_provider_env.lower() @@ -135,5 +146,6 @@ self.EMBEDDING_API_KEY: Optional[str] = api_key_env or api_key_yaml or api_key_pydantic + # Instantiate the single settings object for the application settings = Settings() diff --git a/ai-hub/app/config.yaml b/ai-hub/app/config.yaml index 6ee8d0d..11b1c0a 100644 --- a/ai-hub/app/config.yaml +++ b/ai-hub/app/config.yaml @@ -10,6 +10,10 @@ # for a remote server (requires DATABASE_URL to be set). mode: "sqlite" + # When using SQLite mode, specify the local database file path here. + # This path is relative to the project root and defaults to "./data/ai_hub.db". + local_path: "data/ai_hub.db" + llm_providers: # The default model name for the DeepSeek LLM provider. deepseek_model_name: "deepseek-chat" diff --git a/.gitignore b/.gitignore index 95f61e0..7782b54 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,6 @@ .env **/.env **/*.egg-info -**/faiss_index.bin -**/ai_hub.db -.pytest_cache/ \ No newline at end of file +.pytest_cache/ +**.bin +**.db \ No newline at end of file diff --git a/ai-hub/app/config.py b/ai-hub/app/config.py index 603b12f..2d79f0e 100644 --- a/ai-hub/app/config.py +++ b/ai-hub/app/config.py @@ -24,8 +24,9 @@ log_level: str = "INFO" class DatabaseSettings(BaseModel): - mode: str = "sqlite" - url: str = "postgresql://user:password@localhost/ai_hub_db" + mode: str = "sqlite" # "sqlite" or "postgresql" + url: Optional[str] = None # Used if mode != "sqlite" + local_path: str = "data/ai_hub.db" # Used if mode == "sqlite" class LLMProviderSettings(BaseModel): deepseek_model_name: str = "deepseek-chat" @@ -83,20 +84,31 @@ get_from_yaml(["application", "log_level"]) or \ config_from_pydantic.application.log_level + # --- Database Settings --- self.DB_MODE: str = os.getenv("DB_MODE") or \ get_from_yaml(["database", "mode"]) or \ config_from_pydantic.database.mode - + + # Get local path for SQLite, from env/yaml/pydantic + local_db_path = os.getenv("LOCAL_DB_PATH") or \ + get_from_yaml(["database", "local_path"]) or \ + config_from_pydantic.database.local_path + + # Get external DB URL, from env/yaml/pydantic + external_db_url = os.getenv("DATABASE_URL") or \ + get_from_yaml(["database", "url"]) or \ + config_from_pydantic.database.url + if self.DB_MODE == "sqlite": - self.DATABASE_URL: str = "sqlite:///./data/ai_hub.db" + # Ensure path does not have duplicate ./ prefix + normalized_path = local_db_path.lstrip("./") + self.DATABASE_URL: str = f"sqlite:///./{normalized_path}" if normalized_path else "sqlite:///./data/ai_hub.db" else: - self.DATABASE_URL: str = os.getenv("DATABASE_URL") or \ - get_from_yaml(["database", "url"]) or \ - config_from_pydantic.database.url + self.DATABASE_URL: str = external_db_url or "sqlite:///./data/ai_hub.db" # fallback if no URL provided - self.DEEPSEEK_API_KEY: str = os.getenv("DEEPSEEK_API_KEY") - self.GEMINI_API_KEY: str = os.getenv("GEMINI_API_KEY") - # Removed the ValueError here to allow tests to run + # --- API Keys --- + self.DEEPSEEK_API_KEY: Optional[str] = os.getenv("DEEPSEEK_API_KEY") + self.GEMINI_API_KEY: Optional[str] = os.getenv("GEMINI_API_KEY") self.DEEPSEEK_MODEL_NAME: str = os.getenv("DEEPSEEK_MODEL_NAME") or \ get_from_yaml(["llm_providers", "deepseek_model_name"]) or \ @@ -116,7 +128,6 @@ self.EMBEDDING_DIMENSION: int = int(dimension_str) # New embedding provider settings - # Convert the environment variable value to lowercase to match the enum embedding_provider_env = os.getenv("EMBEDDING_PROVIDER") if embedding_provider_env: embedding_provider_env = embedding_provider_env.lower() @@ -135,5 +146,6 @@ self.EMBEDDING_API_KEY: Optional[str] = api_key_env or api_key_yaml or api_key_pydantic + # Instantiate the single settings object for the application settings = Settings() diff --git a/ai-hub/app/config.yaml b/ai-hub/app/config.yaml index 6ee8d0d..11b1c0a 100644 --- a/ai-hub/app/config.yaml +++ b/ai-hub/app/config.yaml @@ -10,6 +10,10 @@ # for a remote server (requires DATABASE_URL to be set). mode: "sqlite" + # When using SQLite mode, specify the local database file path here. + # This path is relative to the project root and defaults to "./data/ai_hub.db". + local_path: "data/ai_hub.db" + llm_providers: # The default model name for the DeepSeek LLM provider. deepseek_model_name: "deepseek-chat" diff --git a/ai-hub/app/core/pipelines/dspy_rag.py b/ai-hub/app/core/pipelines/dspy_rag.py index 49cbe2c..d5e44de 100644 --- a/ai-hub/app/core/pipelines/dspy_rag.py +++ b/ai-hub/app/core/pipelines/dspy_rag.py @@ -1,14 +1,14 @@ import dspy import logging -from typing import List +from typing import List, Callable, Optional from types import SimpleNamespace from sqlalchemy.orm import Session -from app.db import models # Import your SQLAlchemy models +from app.db import models from app.core.retrievers import Retriever from app.core.llm_providers import LLMProvider -# (The DSPyLLMProvider class is unchanged) + class DSPyLLMProvider(dspy.BaseLM): def __init__(self, provider: LLMProvider, model_name: str, **kwargs): super().__init__(model=model_name) @@ -22,63 +22,89 @@ choice = SimpleNamespace(message=SimpleNamespace(content=response_text)) return SimpleNamespace(choices=[choice]) -# --- 1. Update the Signature to include Chat History --- + class AnswerWithHistory(dspy.Signature): """Given the context and chat history, answer the user's question.""" - context = dspy.InputField(desc="Relevant document snippets from the knowledge base.") chat_history = dspy.InputField(desc="The ongoing conversation between the user and the AI.") question = dspy.InputField() answer = dspy.OutputField() + class DspyRagPipeline(dspy.Module): """ - A conversational RAG pipeline that uses document context and chat history. + A flexible and extensible DSPy-based RAG pipeline with modular stages. """ - def __init__(self, retrievers: List[Retriever]): + + def __init__( + self, + retrievers: List[Retriever], + signature_class: dspy.Signature = AnswerWithHistory, + context_postprocessor: Optional[Callable[[List[str]], str]] = None, + history_formatter: Optional[Callable[[List[models.Message]], str]] = None, + response_postprocessor: Optional[Callable[[str], str]] = None, + ): super().__init__() self.retrievers = retrievers - # Use the new signature that includes history - self.generate_answer = dspy.Predict(AnswerWithHistory) + self.generate_answer = dspy.Predict(signature_class) - # --- 2. Update the `forward` method to accept history --- + self.context_postprocessor = context_postprocessor or self._default_context_postprocessor + self.history_formatter = history_formatter or self._default_history_formatter + self.response_postprocessor = response_postprocessor + async def forward(self, question: str, history: List[models.Message], db: Session) -> str: - """ - Executes the RAG pipeline using the question and the conversation history. - """ logging.debug(f"[DspyRagPipeline.forward] Received question: '{question}'") - - # Retrieve document context based on the current question - retrieved_contexts = [] + + # Step 1: Retrieve all document contexts + context_chunks = [] for retriever in self.retrievers: - context = retriever.retrieve_context(question, db) - retrieved_contexts.extend(context) + context_chunks.extend(retriever.retrieve_context(question, db)) - context_text = "\n\n".join(retrieved_contexts) or "No context provided." + context_text = self.context_postprocessor(context_chunks) - # --- 3. Format the chat history into a string --- - history_str = "\n".join( + # Step 2: Format history + history_text = self.history_formatter(history) + + # Step 3: Build final prompt + instruction = self.generate_answer.signature.__doc__ + full_prompt = self._build_prompt(instruction, context_text, history_text, question) + + logging.debug(f"[DspyRagPipeline.forward] Full Prompt:\n{full_prompt}") + + # Step 4: Generate response using LLM + lm = dspy.settings.lm + if lm is None: + raise RuntimeError("DSPy LM not configured.") + + response_obj = await lm.aforward(prompt=full_prompt) + raw_response = response_obj.choices[0].message.content + + # Step 5: Optional response postprocessing + if self.response_postprocessor: + return self.response_postprocessor(raw_response) + + return raw_response + + # Default context processor: concatenate chunks + def _default_context_postprocessor(self, contexts: List[str]) -> str: + return "\n\n".join(contexts) or "No context provided." + + # Default history formatter: simple speaker prefix + def _default_history_formatter(self, history: List[models.Message]) -> str: + return "\n".join( f"{'Human' if msg.sender == 'user' else 'Assistant'}: {msg.content}" for msg in history ) - # --- 4. Build the final prompt including history --- - instruction = self.generate_answer.signature.__doc__ - full_prompt = ( - f"{instruction}\n\n" + # Prompt builder + def _build_prompt(self, instruction: str, context: str, history: str, question: str) -> str: + return ( + f"{instruction.strip()}\n\n" f"---\n\n" - f"Context: {context_text}\n\n" + f"Context:\n{context.strip()}\n\n" f"---\n\n" - f"Chat History:\n{history_str}\n\n" + f"Chat History:\n{history.strip()}\n\n" f"---\n\n" - f"Human: {question}\n" + f"Human: {question.strip()}\n" f"Assistant:" ) - logging.debug(f"[DspyRagPipeline.forward] Full Prompt:\n{full_prompt}") - - lm = dspy.settings.lm - if lm is None: - raise RuntimeError("DSPy LM not configured.") - - response_obj = await lm.aforward(prompt=full_prompt) - return response_obj.choices[0].message.content \ No newline at end of file diff --git a/.gitignore b/.gitignore index 95f61e0..7782b54 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,6 @@ .env **/.env **/*.egg-info -**/faiss_index.bin -**/ai_hub.db -.pytest_cache/ \ No newline at end of file +.pytest_cache/ +**.bin +**.db \ No newline at end of file diff --git a/ai-hub/app/config.py b/ai-hub/app/config.py index 603b12f..2d79f0e 100644 --- a/ai-hub/app/config.py +++ b/ai-hub/app/config.py @@ -24,8 +24,9 @@ log_level: str = "INFO" class DatabaseSettings(BaseModel): - mode: str = "sqlite" - url: str = "postgresql://user:password@localhost/ai_hub_db" + mode: str = "sqlite" # "sqlite" or "postgresql" + url: Optional[str] = None # Used if mode != "sqlite" + local_path: str = "data/ai_hub.db" # Used if mode == "sqlite" class LLMProviderSettings(BaseModel): deepseek_model_name: str = "deepseek-chat" @@ -83,20 +84,31 @@ get_from_yaml(["application", "log_level"]) or \ config_from_pydantic.application.log_level + # --- Database Settings --- self.DB_MODE: str = os.getenv("DB_MODE") or \ get_from_yaml(["database", "mode"]) or \ config_from_pydantic.database.mode - + + # Get local path for SQLite, from env/yaml/pydantic + local_db_path = os.getenv("LOCAL_DB_PATH") or \ + get_from_yaml(["database", "local_path"]) or \ + config_from_pydantic.database.local_path + + # Get external DB URL, from env/yaml/pydantic + external_db_url = os.getenv("DATABASE_URL") or \ + get_from_yaml(["database", "url"]) or \ + config_from_pydantic.database.url + if self.DB_MODE == "sqlite": - self.DATABASE_URL: str = "sqlite:///./data/ai_hub.db" + # Ensure path does not have duplicate ./ prefix + normalized_path = local_db_path.lstrip("./") + self.DATABASE_URL: str = f"sqlite:///./{normalized_path}" if normalized_path else "sqlite:///./data/ai_hub.db" else: - self.DATABASE_URL: str = os.getenv("DATABASE_URL") or \ - get_from_yaml(["database", "url"]) or \ - config_from_pydantic.database.url + self.DATABASE_URL: str = external_db_url or "sqlite:///./data/ai_hub.db" # fallback if no URL provided - self.DEEPSEEK_API_KEY: str = os.getenv("DEEPSEEK_API_KEY") - self.GEMINI_API_KEY: str = os.getenv("GEMINI_API_KEY") - # Removed the ValueError here to allow tests to run + # --- API Keys --- + self.DEEPSEEK_API_KEY: Optional[str] = os.getenv("DEEPSEEK_API_KEY") + self.GEMINI_API_KEY: Optional[str] = os.getenv("GEMINI_API_KEY") self.DEEPSEEK_MODEL_NAME: str = os.getenv("DEEPSEEK_MODEL_NAME") or \ get_from_yaml(["llm_providers", "deepseek_model_name"]) or \ @@ -116,7 +128,6 @@ self.EMBEDDING_DIMENSION: int = int(dimension_str) # New embedding provider settings - # Convert the environment variable value to lowercase to match the enum embedding_provider_env = os.getenv("EMBEDDING_PROVIDER") if embedding_provider_env: embedding_provider_env = embedding_provider_env.lower() @@ -135,5 +146,6 @@ self.EMBEDDING_API_KEY: Optional[str] = api_key_env or api_key_yaml or api_key_pydantic + # Instantiate the single settings object for the application settings = Settings() diff --git a/ai-hub/app/config.yaml b/ai-hub/app/config.yaml index 6ee8d0d..11b1c0a 100644 --- a/ai-hub/app/config.yaml +++ b/ai-hub/app/config.yaml @@ -10,6 +10,10 @@ # for a remote server (requires DATABASE_URL to be set). mode: "sqlite" + # When using SQLite mode, specify the local database file path here. + # This path is relative to the project root and defaults to "./data/ai_hub.db". + local_path: "data/ai_hub.db" + llm_providers: # The default model name for the DeepSeek LLM provider. deepseek_model_name: "deepseek-chat" diff --git a/ai-hub/app/core/pipelines/dspy_rag.py b/ai-hub/app/core/pipelines/dspy_rag.py index 49cbe2c..d5e44de 100644 --- a/ai-hub/app/core/pipelines/dspy_rag.py +++ b/ai-hub/app/core/pipelines/dspy_rag.py @@ -1,14 +1,14 @@ import dspy import logging -from typing import List +from typing import List, Callable, Optional from types import SimpleNamespace from sqlalchemy.orm import Session -from app.db import models # Import your SQLAlchemy models +from app.db import models from app.core.retrievers import Retriever from app.core.llm_providers import LLMProvider -# (The DSPyLLMProvider class is unchanged) + class DSPyLLMProvider(dspy.BaseLM): def __init__(self, provider: LLMProvider, model_name: str, **kwargs): super().__init__(model=model_name) @@ -22,63 +22,89 @@ choice = SimpleNamespace(message=SimpleNamespace(content=response_text)) return SimpleNamespace(choices=[choice]) -# --- 1. Update the Signature to include Chat History --- + class AnswerWithHistory(dspy.Signature): """Given the context and chat history, answer the user's question.""" - context = dspy.InputField(desc="Relevant document snippets from the knowledge base.") chat_history = dspy.InputField(desc="The ongoing conversation between the user and the AI.") question = dspy.InputField() answer = dspy.OutputField() + class DspyRagPipeline(dspy.Module): """ - A conversational RAG pipeline that uses document context and chat history. + A flexible and extensible DSPy-based RAG pipeline with modular stages. """ - def __init__(self, retrievers: List[Retriever]): + + def __init__( + self, + retrievers: List[Retriever], + signature_class: dspy.Signature = AnswerWithHistory, + context_postprocessor: Optional[Callable[[List[str]], str]] = None, + history_formatter: Optional[Callable[[List[models.Message]], str]] = None, + response_postprocessor: Optional[Callable[[str], str]] = None, + ): super().__init__() self.retrievers = retrievers - # Use the new signature that includes history - self.generate_answer = dspy.Predict(AnswerWithHistory) + self.generate_answer = dspy.Predict(signature_class) - # --- 2. Update the `forward` method to accept history --- + self.context_postprocessor = context_postprocessor or self._default_context_postprocessor + self.history_formatter = history_formatter or self._default_history_formatter + self.response_postprocessor = response_postprocessor + async def forward(self, question: str, history: List[models.Message], db: Session) -> str: - """ - Executes the RAG pipeline using the question and the conversation history. - """ logging.debug(f"[DspyRagPipeline.forward] Received question: '{question}'") - - # Retrieve document context based on the current question - retrieved_contexts = [] + + # Step 1: Retrieve all document contexts + context_chunks = [] for retriever in self.retrievers: - context = retriever.retrieve_context(question, db) - retrieved_contexts.extend(context) + context_chunks.extend(retriever.retrieve_context(question, db)) - context_text = "\n\n".join(retrieved_contexts) or "No context provided." + context_text = self.context_postprocessor(context_chunks) - # --- 3. Format the chat history into a string --- - history_str = "\n".join( + # Step 2: Format history + history_text = self.history_formatter(history) + + # Step 3: Build final prompt + instruction = self.generate_answer.signature.__doc__ + full_prompt = self._build_prompt(instruction, context_text, history_text, question) + + logging.debug(f"[DspyRagPipeline.forward] Full Prompt:\n{full_prompt}") + + # Step 4: Generate response using LLM + lm = dspy.settings.lm + if lm is None: + raise RuntimeError("DSPy LM not configured.") + + response_obj = await lm.aforward(prompt=full_prompt) + raw_response = response_obj.choices[0].message.content + + # Step 5: Optional response postprocessing + if self.response_postprocessor: + return self.response_postprocessor(raw_response) + + return raw_response + + # Default context processor: concatenate chunks + def _default_context_postprocessor(self, contexts: List[str]) -> str: + return "\n\n".join(contexts) or "No context provided." + + # Default history formatter: simple speaker prefix + def _default_history_formatter(self, history: List[models.Message]) -> str: + return "\n".join( f"{'Human' if msg.sender == 'user' else 'Assistant'}: {msg.content}" for msg in history ) - # --- 4. Build the final prompt including history --- - instruction = self.generate_answer.signature.__doc__ - full_prompt = ( - f"{instruction}\n\n" + # Prompt builder + def _build_prompt(self, instruction: str, context: str, history: str, question: str) -> str: + return ( + f"{instruction.strip()}\n\n" f"---\n\n" - f"Context: {context_text}\n\n" + f"Context:\n{context.strip()}\n\n" f"---\n\n" - f"Chat History:\n{history_str}\n\n" + f"Chat History:\n{history.strip()}\n\n" f"---\n\n" - f"Human: {question}\n" + f"Human: {question.strip()}\n" f"Assistant:" ) - logging.debug(f"[DspyRagPipeline.forward] Full Prompt:\n{full_prompt}") - - lm = dspy.settings.lm - if lm is None: - raise RuntimeError("DSPy LM not configured.") - - response_obj = await lm.aforward(prompt=full_prompt) - return response_obj.choices[0].message.content \ No newline at end of file diff --git a/ai-hub/run_integration_tests.sh b/ai-hub/run_integration_tests.sh index 5174546..8e7a120 100644 --- a/ai-hub/run_integration_tests.sh +++ b/ai-hub/run_integration_tests.sh @@ -9,6 +9,10 @@ # You can override the default with a command-line argument, e.g., './run_integration_tests.sh tests/test_app.py' TEST_PATH=${1:-$DEFAULT_TEST_PATH} +DB_MODE=sqlite +export LOCAL_DB_PATH="data/integration_test_ai_hub.db" +export FAISS_INDEX_PATH="data/integration_test_faiss_index.bin" + echo "--- Starting AI Hub Server for Tests ---" # Start the uvicorn server in the background diff --git a/.gitignore b/.gitignore index 95f61e0..7782b54 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,6 @@ .env **/.env **/*.egg-info -**/faiss_index.bin -**/ai_hub.db -.pytest_cache/ \ No newline at end of file +.pytest_cache/ +**.bin +**.db \ No newline at end of file diff --git a/ai-hub/app/config.py b/ai-hub/app/config.py index 603b12f..2d79f0e 100644 --- a/ai-hub/app/config.py +++ b/ai-hub/app/config.py @@ -24,8 +24,9 @@ log_level: str = "INFO" class DatabaseSettings(BaseModel): - mode: str = "sqlite" - url: str = "postgresql://user:password@localhost/ai_hub_db" + mode: str = "sqlite" # "sqlite" or "postgresql" + url: Optional[str] = None # Used if mode != "sqlite" + local_path: str = "data/ai_hub.db" # Used if mode == "sqlite" class LLMProviderSettings(BaseModel): deepseek_model_name: str = "deepseek-chat" @@ -83,20 +84,31 @@ get_from_yaml(["application", "log_level"]) or \ config_from_pydantic.application.log_level + # --- Database Settings --- self.DB_MODE: str = os.getenv("DB_MODE") or \ get_from_yaml(["database", "mode"]) or \ config_from_pydantic.database.mode - + + # Get local path for SQLite, from env/yaml/pydantic + local_db_path = os.getenv("LOCAL_DB_PATH") or \ + get_from_yaml(["database", "local_path"]) or \ + config_from_pydantic.database.local_path + + # Get external DB URL, from env/yaml/pydantic + external_db_url = os.getenv("DATABASE_URL") or \ + get_from_yaml(["database", "url"]) or \ + config_from_pydantic.database.url + if self.DB_MODE == "sqlite": - self.DATABASE_URL: str = "sqlite:///./data/ai_hub.db" + # Ensure path does not have duplicate ./ prefix + normalized_path = local_db_path.lstrip("./") + self.DATABASE_URL: str = f"sqlite:///./{normalized_path}" if normalized_path else "sqlite:///./data/ai_hub.db" else: - self.DATABASE_URL: str = os.getenv("DATABASE_URL") or \ - get_from_yaml(["database", "url"]) or \ - config_from_pydantic.database.url + self.DATABASE_URL: str = external_db_url or "sqlite:///./data/ai_hub.db" # fallback if no URL provided - self.DEEPSEEK_API_KEY: str = os.getenv("DEEPSEEK_API_KEY") - self.GEMINI_API_KEY: str = os.getenv("GEMINI_API_KEY") - # Removed the ValueError here to allow tests to run + # --- API Keys --- + self.DEEPSEEK_API_KEY: Optional[str] = os.getenv("DEEPSEEK_API_KEY") + self.GEMINI_API_KEY: Optional[str] = os.getenv("GEMINI_API_KEY") self.DEEPSEEK_MODEL_NAME: str = os.getenv("DEEPSEEK_MODEL_NAME") or \ get_from_yaml(["llm_providers", "deepseek_model_name"]) or \ @@ -116,7 +128,6 @@ self.EMBEDDING_DIMENSION: int = int(dimension_str) # New embedding provider settings - # Convert the environment variable value to lowercase to match the enum embedding_provider_env = os.getenv("EMBEDDING_PROVIDER") if embedding_provider_env: embedding_provider_env = embedding_provider_env.lower() @@ -135,5 +146,6 @@ self.EMBEDDING_API_KEY: Optional[str] = api_key_env or api_key_yaml or api_key_pydantic + # Instantiate the single settings object for the application settings = Settings() diff --git a/ai-hub/app/config.yaml b/ai-hub/app/config.yaml index 6ee8d0d..11b1c0a 100644 --- a/ai-hub/app/config.yaml +++ b/ai-hub/app/config.yaml @@ -10,6 +10,10 @@ # for a remote server (requires DATABASE_URL to be set). mode: "sqlite" + # When using SQLite mode, specify the local database file path here. + # This path is relative to the project root and defaults to "./data/ai_hub.db". + local_path: "data/ai_hub.db" + llm_providers: # The default model name for the DeepSeek LLM provider. deepseek_model_name: "deepseek-chat" diff --git a/ai-hub/app/core/pipelines/dspy_rag.py b/ai-hub/app/core/pipelines/dspy_rag.py index 49cbe2c..d5e44de 100644 --- a/ai-hub/app/core/pipelines/dspy_rag.py +++ b/ai-hub/app/core/pipelines/dspy_rag.py @@ -1,14 +1,14 @@ import dspy import logging -from typing import List +from typing import List, Callable, Optional from types import SimpleNamespace from sqlalchemy.orm import Session -from app.db import models # Import your SQLAlchemy models +from app.db import models from app.core.retrievers import Retriever from app.core.llm_providers import LLMProvider -# (The DSPyLLMProvider class is unchanged) + class DSPyLLMProvider(dspy.BaseLM): def __init__(self, provider: LLMProvider, model_name: str, **kwargs): super().__init__(model=model_name) @@ -22,63 +22,89 @@ choice = SimpleNamespace(message=SimpleNamespace(content=response_text)) return SimpleNamespace(choices=[choice]) -# --- 1. Update the Signature to include Chat History --- + class AnswerWithHistory(dspy.Signature): """Given the context and chat history, answer the user's question.""" - context = dspy.InputField(desc="Relevant document snippets from the knowledge base.") chat_history = dspy.InputField(desc="The ongoing conversation between the user and the AI.") question = dspy.InputField() answer = dspy.OutputField() + class DspyRagPipeline(dspy.Module): """ - A conversational RAG pipeline that uses document context and chat history. + A flexible and extensible DSPy-based RAG pipeline with modular stages. """ - def __init__(self, retrievers: List[Retriever]): + + def __init__( + self, + retrievers: List[Retriever], + signature_class: dspy.Signature = AnswerWithHistory, + context_postprocessor: Optional[Callable[[List[str]], str]] = None, + history_formatter: Optional[Callable[[List[models.Message]], str]] = None, + response_postprocessor: Optional[Callable[[str], str]] = None, + ): super().__init__() self.retrievers = retrievers - # Use the new signature that includes history - self.generate_answer = dspy.Predict(AnswerWithHistory) + self.generate_answer = dspy.Predict(signature_class) - # --- 2. Update the `forward` method to accept history --- + self.context_postprocessor = context_postprocessor or self._default_context_postprocessor + self.history_formatter = history_formatter or self._default_history_formatter + self.response_postprocessor = response_postprocessor + async def forward(self, question: str, history: List[models.Message], db: Session) -> str: - """ - Executes the RAG pipeline using the question and the conversation history. - """ logging.debug(f"[DspyRagPipeline.forward] Received question: '{question}'") - - # Retrieve document context based on the current question - retrieved_contexts = [] + + # Step 1: Retrieve all document contexts + context_chunks = [] for retriever in self.retrievers: - context = retriever.retrieve_context(question, db) - retrieved_contexts.extend(context) + context_chunks.extend(retriever.retrieve_context(question, db)) - context_text = "\n\n".join(retrieved_contexts) or "No context provided." + context_text = self.context_postprocessor(context_chunks) - # --- 3. Format the chat history into a string --- - history_str = "\n".join( + # Step 2: Format history + history_text = self.history_formatter(history) + + # Step 3: Build final prompt + instruction = self.generate_answer.signature.__doc__ + full_prompt = self._build_prompt(instruction, context_text, history_text, question) + + logging.debug(f"[DspyRagPipeline.forward] Full Prompt:\n{full_prompt}") + + # Step 4: Generate response using LLM + lm = dspy.settings.lm + if lm is None: + raise RuntimeError("DSPy LM not configured.") + + response_obj = await lm.aforward(prompt=full_prompt) + raw_response = response_obj.choices[0].message.content + + # Step 5: Optional response postprocessing + if self.response_postprocessor: + return self.response_postprocessor(raw_response) + + return raw_response + + # Default context processor: concatenate chunks + def _default_context_postprocessor(self, contexts: List[str]) -> str: + return "\n\n".join(contexts) or "No context provided." + + # Default history formatter: simple speaker prefix + def _default_history_formatter(self, history: List[models.Message]) -> str: + return "\n".join( f"{'Human' if msg.sender == 'user' else 'Assistant'}: {msg.content}" for msg in history ) - # --- 4. Build the final prompt including history --- - instruction = self.generate_answer.signature.__doc__ - full_prompt = ( - f"{instruction}\n\n" + # Prompt builder + def _build_prompt(self, instruction: str, context: str, history: str, question: str) -> str: + return ( + f"{instruction.strip()}\n\n" f"---\n\n" - f"Context: {context_text}\n\n" + f"Context:\n{context.strip()}\n\n" f"---\n\n" - f"Chat History:\n{history_str}\n\n" + f"Chat History:\n{history.strip()}\n\n" f"---\n\n" - f"Human: {question}\n" + f"Human: {question.strip()}\n" f"Assistant:" ) - logging.debug(f"[DspyRagPipeline.forward] Full Prompt:\n{full_prompt}") - - lm = dspy.settings.lm - if lm is None: - raise RuntimeError("DSPy LM not configured.") - - response_obj = await lm.aforward(prompt=full_prompt) - return response_obj.choices[0].message.content \ No newline at end of file diff --git a/ai-hub/run_integration_tests.sh b/ai-hub/run_integration_tests.sh index 5174546..8e7a120 100644 --- a/ai-hub/run_integration_tests.sh +++ b/ai-hub/run_integration_tests.sh @@ -9,6 +9,10 @@ # You can override the default with a command-line argument, e.g., './run_integration_tests.sh tests/test_app.py' TEST_PATH=${1:-$DEFAULT_TEST_PATH} +DB_MODE=sqlite +export LOCAL_DB_PATH="data/integration_test_ai_hub.db" +export FAISS_INDEX_PATH="data/integration_test_faiss_index.bin" + echo "--- Starting AI Hub Server for Tests ---" # Start the uvicorn server in the background diff --git a/ai-hub/tests/core/pipelines/test_dspy_rag.py b/ai-hub/tests/core/pipelines/test_dspy_rag.py index 4677ca5..d1b8c62 100644 --- a/ai-hub/tests/core/pipelines/test_dspy_rag.py +++ b/ai-hub/tests/core/pipelines/test_dspy_rag.py @@ -53,7 +53,7 @@ expected_prompt = ( f"{instruction}\n\n" f"---\n\n" - f"Context: Context chunk 1.\n\n" + f"Context:\nContext chunk 1.\n\n" f"---\n\n" f"Chat History:\n{expected_history_str}\n\n" f"---\n\n" @@ -87,7 +87,7 @@ expected_prompt = ( f"{instruction}\n\n" f"---\n\n" - f"Context: No context provided.\n\n" + f"Context:\nNo context provided.\n\n" f"---\n\n" f"Chat History:\n\n\n" # History string is empty f"---\n\n" diff --git a/.gitignore b/.gitignore index 95f61e0..7782b54 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,6 @@ .env **/.env **/*.egg-info -**/faiss_index.bin -**/ai_hub.db -.pytest_cache/ \ No newline at end of file +.pytest_cache/ +**.bin +**.db \ No newline at end of file diff --git a/ai-hub/app/config.py b/ai-hub/app/config.py index 603b12f..2d79f0e 100644 --- a/ai-hub/app/config.py +++ b/ai-hub/app/config.py @@ -24,8 +24,9 @@ log_level: str = "INFO" class DatabaseSettings(BaseModel): - mode: str = "sqlite" - url: str = "postgresql://user:password@localhost/ai_hub_db" + mode: str = "sqlite" # "sqlite" or "postgresql" + url: Optional[str] = None # Used if mode != "sqlite" + local_path: str = "data/ai_hub.db" # Used if mode == "sqlite" class LLMProviderSettings(BaseModel): deepseek_model_name: str = "deepseek-chat" @@ -83,20 +84,31 @@ get_from_yaml(["application", "log_level"]) or \ config_from_pydantic.application.log_level + # --- Database Settings --- self.DB_MODE: str = os.getenv("DB_MODE") or \ get_from_yaml(["database", "mode"]) or \ config_from_pydantic.database.mode - + + # Get local path for SQLite, from env/yaml/pydantic + local_db_path = os.getenv("LOCAL_DB_PATH") or \ + get_from_yaml(["database", "local_path"]) or \ + config_from_pydantic.database.local_path + + # Get external DB URL, from env/yaml/pydantic + external_db_url = os.getenv("DATABASE_URL") or \ + get_from_yaml(["database", "url"]) or \ + config_from_pydantic.database.url + if self.DB_MODE == "sqlite": - self.DATABASE_URL: str = "sqlite:///./data/ai_hub.db" + # Ensure path does not have duplicate ./ prefix + normalized_path = local_db_path.lstrip("./") + self.DATABASE_URL: str = f"sqlite:///./{normalized_path}" if normalized_path else "sqlite:///./data/ai_hub.db" else: - self.DATABASE_URL: str = os.getenv("DATABASE_URL") or \ - get_from_yaml(["database", "url"]) or \ - config_from_pydantic.database.url + self.DATABASE_URL: str = external_db_url or "sqlite:///./data/ai_hub.db" # fallback if no URL provided - self.DEEPSEEK_API_KEY: str = os.getenv("DEEPSEEK_API_KEY") - self.GEMINI_API_KEY: str = os.getenv("GEMINI_API_KEY") - # Removed the ValueError here to allow tests to run + # --- API Keys --- + self.DEEPSEEK_API_KEY: Optional[str] = os.getenv("DEEPSEEK_API_KEY") + self.GEMINI_API_KEY: Optional[str] = os.getenv("GEMINI_API_KEY") self.DEEPSEEK_MODEL_NAME: str = os.getenv("DEEPSEEK_MODEL_NAME") or \ get_from_yaml(["llm_providers", "deepseek_model_name"]) or \ @@ -116,7 +128,6 @@ self.EMBEDDING_DIMENSION: int = int(dimension_str) # New embedding provider settings - # Convert the environment variable value to lowercase to match the enum embedding_provider_env = os.getenv("EMBEDDING_PROVIDER") if embedding_provider_env: embedding_provider_env = embedding_provider_env.lower() @@ -135,5 +146,6 @@ self.EMBEDDING_API_KEY: Optional[str] = api_key_env or api_key_yaml or api_key_pydantic + # Instantiate the single settings object for the application settings = Settings() diff --git a/ai-hub/app/config.yaml b/ai-hub/app/config.yaml index 6ee8d0d..11b1c0a 100644 --- a/ai-hub/app/config.yaml +++ b/ai-hub/app/config.yaml @@ -10,6 +10,10 @@ # for a remote server (requires DATABASE_URL to be set). mode: "sqlite" + # When using SQLite mode, specify the local database file path here. + # This path is relative to the project root and defaults to "./data/ai_hub.db". + local_path: "data/ai_hub.db" + llm_providers: # The default model name for the DeepSeek LLM provider. deepseek_model_name: "deepseek-chat" diff --git a/ai-hub/app/core/pipelines/dspy_rag.py b/ai-hub/app/core/pipelines/dspy_rag.py index 49cbe2c..d5e44de 100644 --- a/ai-hub/app/core/pipelines/dspy_rag.py +++ b/ai-hub/app/core/pipelines/dspy_rag.py @@ -1,14 +1,14 @@ import dspy import logging -from typing import List +from typing import List, Callable, Optional from types import SimpleNamespace from sqlalchemy.orm import Session -from app.db import models # Import your SQLAlchemy models +from app.db import models from app.core.retrievers import Retriever from app.core.llm_providers import LLMProvider -# (The DSPyLLMProvider class is unchanged) + class DSPyLLMProvider(dspy.BaseLM): def __init__(self, provider: LLMProvider, model_name: str, **kwargs): super().__init__(model=model_name) @@ -22,63 +22,89 @@ choice = SimpleNamespace(message=SimpleNamespace(content=response_text)) return SimpleNamespace(choices=[choice]) -# --- 1. Update the Signature to include Chat History --- + class AnswerWithHistory(dspy.Signature): """Given the context and chat history, answer the user's question.""" - context = dspy.InputField(desc="Relevant document snippets from the knowledge base.") chat_history = dspy.InputField(desc="The ongoing conversation between the user and the AI.") question = dspy.InputField() answer = dspy.OutputField() + class DspyRagPipeline(dspy.Module): """ - A conversational RAG pipeline that uses document context and chat history. + A flexible and extensible DSPy-based RAG pipeline with modular stages. """ - def __init__(self, retrievers: List[Retriever]): + + def __init__( + self, + retrievers: List[Retriever], + signature_class: dspy.Signature = AnswerWithHistory, + context_postprocessor: Optional[Callable[[List[str]], str]] = None, + history_formatter: Optional[Callable[[List[models.Message]], str]] = None, + response_postprocessor: Optional[Callable[[str], str]] = None, + ): super().__init__() self.retrievers = retrievers - # Use the new signature that includes history - self.generate_answer = dspy.Predict(AnswerWithHistory) + self.generate_answer = dspy.Predict(signature_class) - # --- 2. Update the `forward` method to accept history --- + self.context_postprocessor = context_postprocessor or self._default_context_postprocessor + self.history_formatter = history_formatter or self._default_history_formatter + self.response_postprocessor = response_postprocessor + async def forward(self, question: str, history: List[models.Message], db: Session) -> str: - """ - Executes the RAG pipeline using the question and the conversation history. - """ logging.debug(f"[DspyRagPipeline.forward] Received question: '{question}'") - - # Retrieve document context based on the current question - retrieved_contexts = [] + + # Step 1: Retrieve all document contexts + context_chunks = [] for retriever in self.retrievers: - context = retriever.retrieve_context(question, db) - retrieved_contexts.extend(context) + context_chunks.extend(retriever.retrieve_context(question, db)) - context_text = "\n\n".join(retrieved_contexts) or "No context provided." + context_text = self.context_postprocessor(context_chunks) - # --- 3. Format the chat history into a string --- - history_str = "\n".join( + # Step 2: Format history + history_text = self.history_formatter(history) + + # Step 3: Build final prompt + instruction = self.generate_answer.signature.__doc__ + full_prompt = self._build_prompt(instruction, context_text, history_text, question) + + logging.debug(f"[DspyRagPipeline.forward] Full Prompt:\n{full_prompt}") + + # Step 4: Generate response using LLM + lm = dspy.settings.lm + if lm is None: + raise RuntimeError("DSPy LM not configured.") + + response_obj = await lm.aforward(prompt=full_prompt) + raw_response = response_obj.choices[0].message.content + + # Step 5: Optional response postprocessing + if self.response_postprocessor: + return self.response_postprocessor(raw_response) + + return raw_response + + # Default context processor: concatenate chunks + def _default_context_postprocessor(self, contexts: List[str]) -> str: + return "\n\n".join(contexts) or "No context provided." + + # Default history formatter: simple speaker prefix + def _default_history_formatter(self, history: List[models.Message]) -> str: + return "\n".join( f"{'Human' if msg.sender == 'user' else 'Assistant'}: {msg.content}" for msg in history ) - # --- 4. Build the final prompt including history --- - instruction = self.generate_answer.signature.__doc__ - full_prompt = ( - f"{instruction}\n\n" + # Prompt builder + def _build_prompt(self, instruction: str, context: str, history: str, question: str) -> str: + return ( + f"{instruction.strip()}\n\n" f"---\n\n" - f"Context: {context_text}\n\n" + f"Context:\n{context.strip()}\n\n" f"---\n\n" - f"Chat History:\n{history_str}\n\n" + f"Chat History:\n{history.strip()}\n\n" f"---\n\n" - f"Human: {question}\n" + f"Human: {question.strip()}\n" f"Assistant:" ) - logging.debug(f"[DspyRagPipeline.forward] Full Prompt:\n{full_prompt}") - - lm = dspy.settings.lm - if lm is None: - raise RuntimeError("DSPy LM not configured.") - - response_obj = await lm.aforward(prompt=full_prompt) - return response_obj.choices[0].message.content \ No newline at end of file diff --git a/ai-hub/run_integration_tests.sh b/ai-hub/run_integration_tests.sh index 5174546..8e7a120 100644 --- a/ai-hub/run_integration_tests.sh +++ b/ai-hub/run_integration_tests.sh @@ -9,6 +9,10 @@ # You can override the default with a command-line argument, e.g., './run_integration_tests.sh tests/test_app.py' TEST_PATH=${1:-$DEFAULT_TEST_PATH} +DB_MODE=sqlite +export LOCAL_DB_PATH="data/integration_test_ai_hub.db" +export FAISS_INDEX_PATH="data/integration_test_faiss_index.bin" + echo "--- Starting AI Hub Server for Tests ---" # Start the uvicorn server in the background diff --git a/ai-hub/tests/core/pipelines/test_dspy_rag.py b/ai-hub/tests/core/pipelines/test_dspy_rag.py index 4677ca5..d1b8c62 100644 --- a/ai-hub/tests/core/pipelines/test_dspy_rag.py +++ b/ai-hub/tests/core/pipelines/test_dspy_rag.py @@ -53,7 +53,7 @@ expected_prompt = ( f"{instruction}\n\n" f"---\n\n" - f"Context: Context chunk 1.\n\n" + f"Context:\nContext chunk 1.\n\n" f"---\n\n" f"Chat History:\n{expected_history_str}\n\n" f"---\n\n" @@ -87,7 +87,7 @@ expected_prompt = ( f"{instruction}\n\n" f"---\n\n" - f"Context: No context provided.\n\n" + f"Context:\nNo context provided.\n\n" f"---\n\n" f"Chat History:\n\n\n" # History string is empty f"---\n\n" diff --git a/ai-hub/tests/test_config.py b/ai-hub/tests/test_config.py index eb7b80d..abea3b0 100644 --- a/ai-hub/tests/test_config.py +++ b/ai-hub/tests/test_config.py @@ -9,6 +9,7 @@ """ Creates a temporary config.yaml file and returns its path. Corrected the 'provider' value to be lowercase 'mock' to match the Enum. + Added database settings for testing. """ config_content = { "application": { @@ -20,6 +21,11 @@ # This value must be lowercase to match the Pydantic Enum member "provider": "mock", "model_name": "embedding-model-from-yaml" + }, + "database": { + "mode": "sqlite", + "local_path": "custom_folder/test_ai_hub.db", + "url": "postgresql://user:pass@host/dbname" # Should be ignored for sqlite mode } } config_path = tmp_path / "test_config.yaml" @@ -36,147 +42,87 @@ """ monkeypatch.setenv("DEEPSEEK_API_KEY", "mock_deepseek_key") monkeypatch.setenv("GEMINI_API_KEY", "mock_gemini_key") - # Also set a default EMBEDDING_API_KEY for completeness monkeypatch.setenv("EMBEDDING_API_KEY", "mock_embedding_key") -def test_env_var_overrides_yaml(monkeypatch, tmp_config_file): - """Tests that an env var overrides YAML for DEEPSEEK_MODEL_NAME.""" +def test_sqlite_db_url_from_yaml(monkeypatch, tmp_config_file): + """Tests DATABASE_URL is constructed correctly from YAML sqlite local_path.""" monkeypatch.setenv("CONFIG_PATH", tmp_config_file) - monkeypatch.setenv("DEEPSEEK_MODEL_NAME", "deepseek-from-env") + monkeypatch.delenv("DB_MODE", raising=False) + monkeypatch.delenv("LOCAL_DB_PATH", raising=False) + monkeypatch.delenv("DATABASE_URL", raising=False) from app import config importlib.reload(config) - assert config.settings.DEEPSEEK_MODEL_NAME == "deepseek-from-env" + expected_path = "sqlite:///./custom_folder/test_ai_hub.db" + assert config.settings.DB_MODE == "sqlite" + assert config.settings.DATABASE_URL == expected_path -def test_env_var_overrides_default(monkeypatch): - """Tests that env var overrides hardcoded default when YAML is missing.""" - monkeypatch.setenv("CONFIG_PATH", "/path/that/does/not/exist.yaml") - monkeypatch.setenv("DEEPSEEK_MODEL_NAME", "deepseek-from-env") +def test_sqlite_db_url_from_env_local_path(monkeypatch, tmp_path): + """Tests that LOCAL_DB_PATH env var overrides YAML for sqlite DATABASE_URL.""" + monkeypatch.setenv("DB_MODE", "sqlite") + monkeypatch.setenv("LOCAL_DB_PATH", "env_folder/env_ai_hub.db") + monkeypatch.delenv("CONFIG_PATH", raising=False) + monkeypatch.delenv("DATABASE_URL", raising=False) from app import config importlib.reload(config) - assert config.settings.DEEPSEEK_MODEL_NAME == "deepseek-from-env" + expected_path = "sqlite:///./env_folder/env_ai_hub.db" + assert config.settings.DB_MODE == "sqlite" + assert config.settings.DATABASE_URL == expected_path -def test_hardcoded_default_is_used_last(monkeypatch): - """Tests fallback to hardcoded default if ENV and YAML missing.""" - monkeypatch.setenv("CONFIG_PATH", "/path/that/does/not/exist.yaml") - monkeypatch.delenv("DEEPSEEK_MODEL_NAME", raising=False) - - from app import config - importlib.reload(config) - - assert config.settings.DEEPSEEK_MODEL_NAME == "deepseek-chat" - - -# -------------------------- -# ✅ LOG_LEVEL TESTS -# -------------------------- - -def test_log_level_env_overrides_yaml(monkeypatch, tmp_config_file): - """Tests LOG_LEVEL: ENV > YAML > default.""" +def test_external_db_url_used_when_not_sqlite(monkeypatch, tmp_config_file): + """Tests DATABASE_URL uses external URL when DB_MODE is not sqlite.""" monkeypatch.setenv("CONFIG_PATH", tmp_config_file) - monkeypatch.setenv("LOG_LEVEL", "DEBUG") + monkeypatch.setenv("DB_MODE", "postgresql") + monkeypatch.setenv("DATABASE_URL", "postgresql://env_user:env_pass@env_host/env_db") + monkeypatch.delenv("LOCAL_DB_PATH", raising=False) from app import config importlib.reload(config) - assert config.settings.LOG_LEVEL == "DEBUG" + assert config.settings.DB_MODE == "postgresql" + assert config.settings.DATABASE_URL == "postgresql://env_user:env_pass@env_host/env_db" -def test_log_level_yaml_over_default(monkeypatch, tmp_config_file): - """Tests LOG_LEVEL uses YAML when ENV is not set.""" - monkeypatch.setenv("CONFIG_PATH", tmp_config_file) - monkeypatch.delenv("LOG_LEVEL", raising=False) +def test_external_db_url_from_yaml_when_not_sqlite(monkeypatch, tmp_path): + """Tests DATABASE_URL uses YAML url when DB_MODE != sqlite and no env DATABASE_URL.""" + # Write YAML with postgresql mode and url + config_content = { + "database": { + "mode": "postgresql", + "url": "postgresql://yaml_user:yaml_pass@yaml_host/yaml_db", + "local_path": "ignored_path_for_postgresql.db" + } + } + config_path = tmp_path / "test_config_pg.yaml" + with open(config_path, 'w') as f: + yaml.dump(config_content, f) + monkeypatch.setenv("CONFIG_PATH", str(config_path)) + monkeypatch.delenv("DATABASE_URL", raising=False) + monkeypatch.delenv("DB_MODE", raising=False) + monkeypatch.delenv("LOCAL_DB_PATH", raising=False) from app import config importlib.reload(config) - assert config.settings.LOG_LEVEL == "WARNING" + assert config.settings.DB_MODE == "postgresql" + assert config.settings.DATABASE_URL == "postgresql://yaml_user:yaml_pass@yaml_host/yaml_db" -def test_log_level_default_used(monkeypatch): - """Tests LOG_LEVEL falls back to default when neither ENV nor YAML set.""" - monkeypatch.setenv("CONFIG_PATH", "/does/not/exist.yaml") - monkeypatch.delenv("LOG_LEVEL", raising=False) +def test_sqlite_db_url_defaults(monkeypatch): + """Tests DATABASE_URL defaults to sqlite path if no env or YAML.""" + monkeypatch.setenv("DB_MODE", "sqlite") + monkeypatch.delenv("LOCAL_DB_PATH", raising=False) + monkeypatch.delenv("DATABASE_URL", raising=False) + monkeypatch.delenv("CONFIG_PATH", raising=False) from app import config importlib.reload(config) - assert config.settings.LOG_LEVEL == "INFO" - - -# -------------------------- -# ✅ EMBEDDING PROVIDER TESTS -# -------------------------- - -def test_embedding_provider_env_overrides_yaml(monkeypatch, tmp_config_file): - """Tests EMBEDDING_PROVIDER: ENV > YAML > default.""" - monkeypatch.setenv("CONFIG_PATH", tmp_config_file) - monkeypatch.setenv("EMBEDDING_PROVIDER", "GOOGLE_GENAI") - - from app import config - importlib.reload(config) - - assert config.settings.EMBEDDING_PROVIDER == EmbeddingProvider.GOOGLE_GENAI - - -def test_embedding_provider_yaml_overrides_default(monkeypatch, tmp_config_file): - """Tests EMBEDDING_PROVIDER uses YAML when ENV is not set.""" - monkeypatch.setenv("CONFIG_PATH", tmp_config_file) - monkeypatch.delenv("EMBEDDING_PROVIDER", raising=False) - - from app import config - importlib.reload(config) - - assert config.settings.EMBEDDING_PROVIDER == EmbeddingProvider.MOCK - - -def test_embedding_provider_default_used(monkeypatch): - """Tests EMBEDDING_PROVIDER falls back to default when neither ENV nor YAML set.""" - monkeypatch.setenv("CONFIG_PATH", "/does/not/exist.yaml") - monkeypatch.delenv("EMBEDDING_PROVIDER", raising=False) - - from app import config - importlib.reload(config) - - assert config.settings.EMBEDDING_PROVIDER == EmbeddingProvider.GOOGLE_GENAI - - -# -------------------------- -# ✅ EMBEDDING MODEL NAME TESTS -# -------------------------- - -def test_embedding_model_name_env_overrides_yaml(monkeypatch, tmp_config_file): - """Tests EMBEDDING_MODEL_NAME: ENV > YAML > default.""" - monkeypatch.setenv("CONFIG_PATH", tmp_config_file) - monkeypatch.setenv("EMBEDDING_MODEL_NAME", "embedding-model-from-env") - - from app import config - importlib.reload(config) - - assert config.settings.EMBEDDING_MODEL_NAME == "embedding-model-from-env" - - -def test_embedding_model_name_yaml_overrides_default(monkeypatch, tmp_config_file): - """Tests EMBEDDING_MODEL_NAME uses YAML when ENV is not set.""" - monkeypatch.setenv("CONFIG_PATH", tmp_config_file) - monkeypatch.delenv("EMBEDDING_MODEL_NAME", raising=False) - - from app import config - importlib.reload(config) - - assert config.settings.EMBEDDING_MODEL_NAME == "embedding-model-from-yaml" - - -def test_embedding_model_name_default_used(monkeypatch): - """Tests EMBEDDING_MODEL_NAME falls back to default when neither ENV nor YAML set.""" - monkeypatch.setenv("CONFIG_PATH", "/does/not/exist.yaml") - from app import config - importlib.reload(config) - - assert config.settings.EMBEDDING_MODEL_NAME == "models/text-embedding-004" + assert config.settings.DB_MODE == "sqlite" + assert config.settings.DATABASE_URL == "sqlite:///./data/ai_hub.db"