diff --git a/ai-hub/app/api/routes/user.py b/ai-hub/app/api/routes/user.py index 77b0598..e10d3e1 100644 --- a/ai-hub/app/api/routes/user.py +++ b/ai-hub/app/api/routes/user.py @@ -69,14 +69,16 @@ # SECURITY: Prevent Open Redirect - Validate 'state' is a safe URL # Ideally this matches settings.FRONTEND_URL or a whitelist. safe_url = state - if not state.startswith(settings.OIDC_SERVER_URL) and "http" in state: - # Basic check: If it's an absolute URL, it must be on our frontend or server - # For now, we enforce that it doesn't leave the intended context. - # In production, this should be a rigorous host check. - if not state.startswith(request.base_url._url.split("/api")[0]): - logger.warning(f"Prevented potentially malicious open redirect to: {state}") - # Fallback to local admin as a safety valve, or raise 400 - safe_url = "/dashboard" + parsed_url = urllib.parse.urlparse(state) + if parsed_url.netloc: + # Absolute URL, verify domain + allowed_domains = ["ai.jerxie.com", "localhost", "127.0.0.1"] + api_domain = urllib.parse.urlparse(str(request.base_url)).netloc + allowed_domains.append(api_domain) + + if parsed_url.netloc not in allowed_domains: + logger.warning(f"Prevented potentially malicious open redirect to: {state}") + safe_url = "/dashboard" frontend_redirect_url = f"{safe_url}?user_id={user_id}" if linked: @@ -474,6 +476,7 @@ @router.get("/me/config/export", summary="Export Configurations to YAML") async def export_user_config_yaml( + reveal_secrets: bool = Query(False, description="Reveal secrets in export"), db: Session = Depends(get_db), user_id: str = Depends(get_current_user_id) ): @@ -489,6 +492,13 @@ from app.config import settings import yaml + def mask_secret(value): + if not value: + return None + if reveal_secrets: + return value + return "***" + llm_prefs = prefs_dict.get("llm", {}) tts_prefs = prefs_dict.get("tts", {}) stt_prefs = prefs_dict.get("stt", {}) @@ -498,15 +508,15 @@ if not user_providers: # Fallback to system defaults if no user config exists llm_providers_export = { - "deepseek_api_key": settings.DEEPSEEK_API_KEY, + "deepseek_api_key": mask_secret(settings.DEEPSEEK_API_KEY), "deepseek_model_name": settings.DEEPSEEK_MODEL_NAME, - "gemini_api_key": settings.GEMINI_API_KEY, + "gemini_api_key": mask_secret(settings.GEMINI_API_KEY), "gemini_model_name": settings.GEMINI_MODEL_NAME, - "openai_api_key": settings.OPENAI_API_KEY + "openai_api_key": mask_secret(settings.OPENAI_API_KEY) } else: for p, p_data in user_providers.items(): - llm_providers_export[f"{p}_api_key"] = p_data.get("api_key") + llm_providers_export[f"{p}_api_key"] = mask_secret(p_data.get("api_key")) llm_providers_export[f"{p}_model_name"] = p_data.get("model") def get_provider_export(section_prefs, fallback_provider, fallback_model, fallback_api_key, fallback_voice=None): @@ -518,21 +528,19 @@ "provider": active_p, "model_name": p_data.get("model"), "voice_name": p_data.get("voice"), - "api_key": p_data.get("api_key") + "api_key": mask_secret(p_data.get("api_key")) } # Fallback to system settings return { "provider": fallback_provider, "model_name": fallback_model, "voice_name": fallback_voice, - "api_key": fallback_api_key + "api_key": mask_secret(fallback_api_key) } # Layer 2 (Day 2) Export: Only LLM, TTS, STT yaml_data = { - "llm_providers": { - "providers": user_providers or settings.LLM_PROVIDERS - }, + "llm_providers": llm_providers_export, "tts_provider": get_provider_export(tts_prefs, settings.TTS_PROVIDER, settings.TTS_MODEL_NAME, settings.TTS_API_KEY, settings.TTS_VOICE_NAME), "stt_provider": get_provider_export(stt_prefs, settings.STT_PROVIDER, settings.STT_MODEL_NAME, settings.STT_API_KEY) } diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index c8587c6..42693a2 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -318,7 +318,7 @@ CORSMiddleware, allow_origins=cors_origins, allow_credentials=True, - allow_methods=["*"], # Allows all HTTP methods (GET, POST, PUT, DELETE, etc.) - allow_headers=["*"], # Allows all headers + allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], + allow_headers=["Content-Type", "Authorization"], ) return app diff --git a/ai-hub/app/db/models/asset.py b/ai-hub/app/db/models/asset.py index af17e56..f0efc73 100644 --- a/ai-hub/app/db/models/asset.py +++ b/ai-hub/app/db/models/asset.py @@ -2,6 +2,32 @@ from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey, Boolean, JSON, LargeBinary from sqlalchemy.orm import relationship from ..database import Base +import json +import base64 +from hashlib import sha256 +from cryptography.fernet import Fernet +from app.config import settings +from sqlalchemy.types import TypeDecorator + +class EncryptedJSON(TypeDecorator): + impl = Text + + def process_bind_param(self, value, dialect): + if value is None: + return None + key = base64.urlsafe_b64encode(sha256(settings.SECRET_KEY.encode()).digest()) + fernet = Fernet(key) + data = json.dumps(value).encode() + encrypted = fernet.encrypt(data) + return encrypted.decode('utf-8') + + def process_result_value(self, value, dialect): + if value is None: + return None + key = base64.urlsafe_b64encode(sha256(settings.SECRET_KEY.encode()).digest()) + fernet = Fernet(key) + decrypted = fernet.decrypt(value.encode('utf-8')) + return json.loads(decrypted.decode('utf-8')) class PromptTemplate(Base): __tablename__ = 'prompt_templates' @@ -12,7 +38,7 @@ content = Column(Text, nullable=False) version = Column(Integer, default=1) - owner_id = Column(String, ForeignKey('users.id'), nullable=False) + owner_id = Column(String, ForeignKey('users.id'), nullable=False, index=True) group_id = Column(String, ForeignKey('groups.id'), nullable=True) is_public = Column(Boolean, default=False) @@ -84,7 +110,7 @@ id = Column(Integer, primary_key=True, index=True) name = Column(String, nullable=False) url = Column(String, nullable=False) - auth_config = Column(JSON, default={}, nullable=True) + auth_config = Column(EncryptedJSON, default={}, nullable=True) owner_id = Column(String, ForeignKey('users.id'), nullable=False) group_id = Column(String, ForeignKey('groups.id'), nullable=True) diff --git a/ai-hub/tests/db/test_models.py b/ai-hub/tests/db/test_models.py index 71dbcec..9edc256 100644 --- a/ai-hub/tests/db/test_models.py +++ b/ai-hub/tests/db/test_models.py @@ -199,3 +199,40 @@ # Check that the document is gone and the vector metadata has been cascaded assert db_session.query(Document).filter(Document.id == new_document.id).count() == 0 assert db_session.query(VectorMetadata).filter(VectorMetadata.document_id == new_document.id).count() == 0 + +def test_encrypted_json_works(db_session): + """Test that EncryptedJSON encrypts data in DB and decrypts on load.""" + from app.db.models import MCPServer, User + + # Create a user first (owner of the asset) + user = User(id="test_user_id", email="test@example.com", username="testuser") + db_session.add(user) + db_session.commit() + + # Create a test config + auth_config = {"api_key": "secret_api_key_123", "token": "abcde"} + + # Create MCPServer with encrypted config + server = MCPServer( + name="Test MCP Server", + url="http://mcp-server:8080", + auth_config=auth_config, + owner_id=user.id + ) + + db_session.add(server) + db_session.commit() + db_session.refresh(server) + + # Verify decryption on load + assert server.auth_config == auth_config + + # Verify encryption in DB (direct SQL query to check raw value) + from sqlalchemy import text + result = db_session.execute(text("SELECT auth_config FROM mcp_servers WHERE id = :id"), {"id": server.id}).fetchone() + + raw_value = result[0] + assert raw_value is not None + assert "secret_api_key_123" not in raw_value + assert "abcde" not in raw_value + diff --git a/ai-hub/tests/test_app.py b/ai-hub/tests/test_app.py index 801860f..757b178 100644 --- a/ai-hub/tests/test_app.py +++ b/ai-hub/tests/test_app.py @@ -55,6 +55,31 @@ @patch('app.app.ServiceContainer') @patch('app.app.create_db_and_tables') +@patch('app.app.print_config') +@patch('app.core.vector_store.embedder.factory.get_embedder_from_config') +@patch('faiss.read_index') +@patch('os.makedirs') +def test_cors_restricted(mock_makedirs, mock_read_index, mock_get_embedder, mock_print_config, mock_create_db, mock_service_container): + """Test that CORS is restricted and not permissive.""" + mock_read_index.return_value = MagicMock() + mock_get_embedder.return_value = MagicMock() + mock_services = MagicMock() + mock_service_container.return_value = mock_services + + app = create_app() + client = TestClient(app) + + response = client.get( + "/", + headers={ + "Origin": "http://evil.com", + } + ) + + assert "Access-Control-Allow-Origin" not in response.headers + +@patch('app.app.ServiceContainer') +@patch('app.app.create_db_and_tables') @patch('app.core.vector_store.embedder.factory.get_embedder_from_config') @patch('faiss.read_index') def test_create_session_success(mock_read_index, mock_get_embedder, mock_create_db, mock_service_container): @@ -375,3 +400,78 @@ # Assert # Check that the new save_index_and_metadata method was called exactly once. mock_save_index.assert_called_once() + +@patch('app.app.ServiceContainer') +@patch('app.app.create_db_and_tables') +@patch('app.app.print_config') +@patch('app.core.vector_store.embedder.factory.get_embedder_from_config') +@patch('faiss.read_index') +@patch('os.makedirs') +def test_open_redirect_protection(mock_makedirs, mock_read_index, mock_get_embedder, mock_print_config, mock_create_db, mock_service_container): + """Test that Open Redirect in OIDC callback is prevented.""" + mock_read_index.return_value = MagicMock() + mock_get_embedder.return_value = MagicMock() + mock_services = MagicMock() + mock_service_container.return_value = mock_services + + # Mock auth_service.handle_callback + mock_services.auth_service.handle_callback = AsyncMock(return_value={"user_id": "test_user", "linked": False}) + + app = create_app() + client = TestClient(app) + + # Test with malicious state + response = client.get( + "/api/v1/users/login/callback?code=dummy_code&state=https://evil.com/malicious", + follow_redirects=False + ) + + assert response.status_code == 307 + location = response.headers.get("Location") + assert location is not None + assert "evil.com" not in location + assert "/dashboard" in location + +@patch('app.app.ServiceContainer') +@patch('app.app.create_db_and_tables') +@patch('app.app.print_config') +@patch('app.core.vector_store.embedder.factory.get_embedder_from_config') +@patch('faiss.read_index') +@patch('os.makedirs') +def test_config_export_redacts_secrets(mock_makedirs, mock_read_index, mock_get_embedder, mock_print_config, mock_create_db, mock_service_container): + """Test that config export redacts secrets.""" + mock_read_index.return_value = MagicMock() + mock_get_embedder.return_value = MagicMock() + mock_services = MagicMock() + mock_service_container.return_value = mock_services + + # Mock user_service.get_user_by_id + mock_user = MagicMock() + mock_user.id = "admin_user" + mock_user.role = "admin" + mock_user.preferences = { + "llm": { + "providers": { + "custom_provider": {"api_key": "my_secret_key", "model": "gpt-4"} + } + } + } + mock_services.user_service.get_user_by_id.return_value = mock_user + + app = create_app() + client = TestClient(app) + + response = client.get( + "/api/v1/users/me/config/export", + headers={"X-User-ID": "admin_user"} + ) + + assert response.status_code == 200 + + import yaml + data = yaml.safe_load(response.text) + + providers = data.get("llm_providers", {}).get("providers", {}) + custom_provider = providers.get("custom_provider", {}) + assert custom_provider.get("api_key") == "***" +