import logging
import re
from typing import List, Tuple, Optional

logger = logging.getLogger(__name__)
from sqlalchemy.orm import Session, joinedload

from app.db import models
from app.core.retrievers.faiss_db_retriever import FaissDBRetriever
from app.core.retrievers.base_retriever import Retriever
from app.core.providers.factory import get_llm_provider
from app.core.orchestration import Architect
from app.core.orchestration.profiles import get_profile

class RAGService:
    """
    Service for orchestrating conversational RAG pipelines.

    Responsibilities:
      - persisting user/assistant messages and auto-titling a chat session,
      - resolving the LLM provider from user preferences with system fallback,
      - gathering document (FAISS) and agent-mesh context for the prompt,
      - streaming Architect events back to the API layer in real time.
    """

    # Compiled once at class level (was re-compiled per node inside the chat
    # loop). Strips ANSI escape sequences from captured terminal output.
    _ANSI_ESCAPE_RE = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')

    # Cap on terminal-history characters forwarded to the LLM, to avoid
    # bloating the context / breaking LLMs.
    _TERMINAL_HISTORY_LIMIT = 2000

    def __init__(
        self,
        retrievers: List[Retriever],
        prompt_service=None,
        tool_service=None,
        node_registry_service=None,
    ):
        """
        :param retrievers: all available retrievers; the first FaissDBRetriever
            (if any) is picked out for optional document-context retrieval.
        :param prompt_service: optional prompt service forwarded to the Architect.
        :param tool_service: optional tool service used to list available tools.
        :param node_registry_service: optional registry of live agent nodes.
        """
        self.retrievers = retrievers
        self.prompt_service = prompt_service
        self.tool_service = tool_service
        self.node_registry_service = node_registry_service
        self.faiss_retriever = next(
            (r for r in retrievers if isinstance(r, FaissDBRetriever)), None
        )

    def _resolve_llm_provider(self, db: Session, session, provider_name: str, user_service):
        """
        Build an LLM provider from per-user preferences, falling back to
        system-level settings when the user's key is absent or masked.

        A key containing "*" is treated as a redacted echo from the client,
        i.e. effectively absent.
        """
        llm_prefs = {}
        user = session.user
        if user and user.preferences:
            llm_prefs = user.preferences.get("llm", {}).get("providers", {}).get(provider_name, {})

        needs_fallback = (
            not llm_prefs
            or not llm_prefs.get("api_key")
            or "*" in str(llm_prefs.get("api_key"))
        )
        if needs_fallback and user_service:
            system_prefs = user_service.get_system_settings(db)
            system_provider_prefs = system_prefs.get("llm", {}).get("providers", {}).get(provider_name, {})
            if system_provider_prefs:
                merged = system_provider_prefs.copy()
                if llm_prefs:
                    # Truthy user-set values win over system defaults.
                    merged.update({k: v for k, v in llm_prefs.items() if v})
                llm_prefs = merged

        api_key_override = llm_prefs.get("api_key")
        model_name_override = llm_prefs.get("model", "")
        # Everything else in the prefs dict is passed through as provider kwargs.
        kwargs = {k: v for k, v in llm_prefs.items() if k not in ["api_key", "model"]}
        return get_llm_provider(
            provider_name,
            model_name=model_name_override,
            api_key_override=api_key_override,
            **kwargs
        )

    def _build_mesh_context(self, db: Session, session, user_service) -> str:
        """
        Describe the session's attached agent nodes (identity, hardware,
        privilege level, shell sandbox policy, recent terminal output) as a
        plain-text block for the LLM prompt.

        Returns an empty string when no attached node rows are found.
        """
        nodes = db.query(models.AgentNode).filter(
            models.AgentNode.node_id.in_(session.attached_node_ids)
        ).all()
        if not nodes:
            return ""

        # Resolve the live-node registry once — it is invariant across nodes.
        registry = getattr(self, "node_registry_service", None)
        if not registry and user_service:
            registry = getattr(user_service, "node_registry_service", None)

        mesh_context = "Attached Agent Nodes (Infrastructure):\n"
        for node in nodes:
            mesh_context += f"- Node ID: {node.node_id}\n"
            mesh_context += f"  Name: {node.display_name}\n"
            mesh_context += f"  Description: {node.description or 'No description provided.'}\n"
            mesh_context += f"  Status: {node.last_status}\n"

            caps = node.capabilities or {}
            if caps.get("local_ip"):
                mesh_context += f"  Local IP: {caps.get('local_ip')}\n"
            if caps.get("arch"):
                mesh_context += f"  Architecture: {caps['arch']} ({caps.get('os', 'unknown')})\n"
            if caps.get("gpu") and caps["gpu"] != "none":
                mesh_context += f"  GPU: {caps['gpu']}\n"

            # Privilege level — critical for knowing whether to use sudo.
            # Values are stored as strings ("true"/"false") due to protobuf
            # map<string,string>, so both string and bool forms are accepted.
            is_root = caps.get("is_root")
            has_sudo = caps.get("has_sudo")
            if is_root == "true" or is_root is True:
                mesh_context += f"  Privilege Level: root (skip sudo — run all commands directly)\n"
            elif has_sudo == "true" or has_sudo is True:
                mesh_context += f"  Privilege Level: standard user with passwordless sudo\n"
            elif is_root == "false" or is_root is False:
                mesh_context += f"  Privilege Level: standard user (sudo NOT available — avoid privileged ops)\n"
            # If neither field exists yet (old node version), omit to avoid confusion.

            shell_config = (node.skill_config or {}).get("shell", {})
            if shell_config.get("enabled"):
                sandbox = shell_config.get("sandbox") or {}
                mode = sandbox.get("mode", "PERMISSIVE")
                allowed = sandbox.get("allowed_commands", [])
                denied = sandbox.get("denied_commands", [])

                mesh_context += f"  Terminal Sandbox Mode: {mode}\n"
                if mode == "STRICT":
                    mesh_context += f"  AI Permitted Commands (Allow-list): {', '.join(allowed) if allowed else 'None'}\n"
                elif mode == "PERMISSIVE":
                    mesh_context += f"  AI Restricted Commands (Blacklist): {', '.join(denied) if denied else 'None'}\n"

                if mode == "STRICT" and not allowed:
                    mesh_context += "  ⚠️ Warning: All shell commands are currently blocked by sandbox policy.\n"

            # AI visibility: recent terminal history from the live registry.
            if registry:
                live = registry.get_node(node.node_id)
                if live and live.terminal_history:
                    # Grab the most recent chunks and join them.
                    history_blob = "".join(live.terminal_history[-40:])

                    # Strip ANSI again just in case, and cap total size.
                    clean_history = self._ANSI_ESCAPE_RE.sub('', history_blob)
                    if len(clean_history) > self._TERMINAL_HISTORY_LIMIT:
                        clean_history = "...[truncated]...\n" + clean_history[-self._TERMINAL_HISTORY_LIMIT:]

                    mesh_context += "  Recent Terminal Output:\n"
                    mesh_context += "  ```\n"
                    mesh_context += f"  {clean_history}"
                    if not clean_history.endswith('\n'):
                        mesh_context += "\n"
                    mesh_context += "  ```\n"
        mesh_context += "\n"
        return mesh_context

    async def chat_with_rag(
        self,
        db: Session,
        session_id: int,
        prompt: str,
        provider_name: str,
        load_faiss_retriever: bool = False,
        user_service=None,
        user_id: Optional[str] = None
    ):
        """
        Processes a user prompt within a session, yields events in real-time,
        and saves the chat history at the end.

        :raises ValueError: if no session with ``session_id`` exists.
        :yields: Architect events (``content``/``reasoning``/...) followed by a
            final ``finish`` event carrying the saved assistant message id.
        """
        session = db.query(models.Session).options(
            joinedload(models.Session.messages)
        ).filter(models.Session.id == session_id).first()

        if not session:
            raise ValueError(f"Session with ID {session_id} not found.")

        # Persist the user message first so it is part of history even if
        # streaming fails later.
        user_message = models.Message(session_id=session_id, sender="user", content=prompt)
        db.add(user_message)
        db.commit()
        db.refresh(user_message)

        # Auto-title the session from the very first user message
        if session.title in (None, "New Chat Session", ""):
            session.title = prompt[:60].strip() + ("..." if len(prompt) > 60 else "")

        # Keep provider_name in sync
        if session.provider_name != provider_name:
            session.provider_name = provider_name

        db.commit()

        llm_provider = self._resolve_llm_provider(db, session, provider_name, user_service)

        context_chunks = []
        if load_faiss_retriever and self.faiss_retriever:
            context_chunks.extend(self.faiss_retriever.retrieve_context(query=prompt, db=db))

        architect = Architect()

        tools = []
        if self.tool_service:
            tools = self.tool_service.get_available_tools(db, session.user_id, feature=session.feature_name)

        profile = get_profile(session.feature_name)
        mesh_context = ""
        if session.attached_node_ids and profile.include_mesh_context:
            mesh_context = self._build_mesh_context(db, session, user_service)

        logger.info(f"[RAG] Mesh Context gathered. Length: {len(mesh_context)} chars.")
        if mesh_context:
            logger.info(f"[RAG] Mesh Context excerpt: {mesh_context[:200]}...")

        # Accumulators for the DB save at the end
        full_answer = ""
        full_reasoning = ""

        # Stream from specialized Architect, forwarding each event to the API.
        async for event in architect.run(
            question=prompt,
            history=session.messages,
            context_chunks=context_chunks,
            llm_provider=llm_provider,
            prompt_service=self.prompt_service,
            tool_service=self.tool_service,
            tools=tools,
            mesh_context=mesh_context,
            db=db,
            user_id=user_id or session.user_id,
            sync_workspace_id=session.sync_workspace_id,
            session_id=session_id,
            feature_name=session.feature_name,
            prompt_slug=profile.default_prompt_slug
        ):
            if event["type"] == "content":
                full_answer += event["content"]
            elif event["type"] == "reasoning":
                full_reasoning += event["content"]

            yield event

        # Save assistant's response to DB
        assistant_message = models.Message(
            session_id=session_id,
            sender="assistant",
            content=full_answer,
        )
        # Optional: only set when the model exposes a reasoning_content field.
        if full_reasoning and hasattr(assistant_message, "reasoning_content"):
            assistant_message.reasoning_content = full_reasoning

        db.add(assistant_message)
        db.commit()
        db.refresh(assistant_message)

        # Yield a final finish event with metadata
        yield {
            "type": "finish",
            "message_id": assistant_message.id,
            "provider": provider_name,
            "full_answer": full_answer
        }

    def get_message_history(self, db: Session, session_id: int) -> Optional[List[models.Message]]:
        """
        Retrieves all messages for a given session, ordered by creation time.

        Returns ``None`` when the session does not exist (annotation fixed to
        ``Optional`` — the previous ``List`` annotation did not reflect this).
        """
        session = db.query(models.Session).options(
            joinedload(models.Session.messages)
        ).filter(models.Session.id == session_id).first()

        return sorted(session.messages, key=lambda msg: msg.created_at) if session else None