import logging
import time
from typing import List, Optional, Dict, Any, AsyncGenerator, Tuple
from sqlalchemy.orm import Session, joinedload

from app.db import models
from app.core.retrievers.faiss_db_retriever import FaissDBRetriever
from app.core.retrievers.base_retriever import Retriever
from app.core.providers.factory import get_llm_provider
from app.core.orchestration import Architect
from app.core.orchestration.profiles import get_profile
from app.core._regex import ANSI_ESCAPE
from app.db.session import async_db_op

logger = logging.getLogger(__name__)


class RAGService:
    """
    Orchestrates conversational RAG pipelines.

    One conversational turn (``chat_with_rag``) proceeds as follows:

    - Resolve the chat session and the LLM provider.
    - Gather optional FAISS context and context about attached agent nodes.
    - Stream events from the ``Architect`` orchestrator back to the caller.
    - Persist the assistant's reply incrementally so the frontend can observe it while
      it is being generated.
    """

    # Throttle for incremental persistence of streamed assistant output: commit after
    # this many content/reasoning chunks, or once this many seconds have passed since
    # the last commit, whichever comes first.
    _COMMIT_EVERY_N_CHUNKS = 50
    _COMMIT_INTERVAL_SECONDS = 1.0

    # Limits for the terminal tail embedded in the mesh context.
    _TERMINAL_TAIL_CHUNKS = 40        # most recent history entries considered
    _TERMINAL_TAIL_SCAN_CHARS = 4000  # stop collecting entries once this is exceeded
    _TERMINAL_TAIL_MAX_CHARS = 2000   # characters kept after ANSI stripping

    def __init__(self, retrievers: List[Retriever], prompt_service=None, tool_service=None, node_registry_service=None, services=None):
        """
        Args:
            retrievers: Available context retrievers. The first ``FaissDBRetriever`` among
                them (if any) is used when ``chat_with_rag`` is called with
                ``load_faiss_retriever=True``.
            prompt_service: Passed through to the ``Architect``.
            tool_service: Supplies the tool list for a session; also passed to the ``Architect``.
            node_registry_service: Preferred source of live node state for terminal tails.
            services: Service container; its ``preference_service`` (if present) is used to
                resolve the LLM provider.
        """
        self.retrievers = retrievers
        self.prompt_service = prompt_service
        self.tool_service = tool_service
        self.node_registry_service = node_registry_service
        self.services = services
        self.faiss_retriever = next((r for r in retrievers if isinstance(r, FaissDBRetriever)), None)

    async def chat_with_rag(
        self,
        db: Session,
        session_id: int,
        prompt: str,
        provider_name: str,
        load_faiss_retriever: bool = False,
        user_service=None,
        user_id: Optional[str] = None,
        save_prompt: bool = True
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Entry point for the RAG pipeline.

        Yields every event produced by ``Architect.run`` unchanged. It then yields a final
        ``{"type": "finish", ...}`` event carrying:

        - the persisted assistant message id;
        - the resolved provider name;
        - the full answer;
        - per-tool call statistics;
        - the accumulated input/output token counts.

        Raises:
            ValueError: If ``session_id`` does not refer to an existing session.
        """
        session = self._resolve_session(db, session_id, prompt, save_prompt=save_prompt)
        llm_provider, resolved_provider_name = self._resolve_provider(db, session, provider_name, user_service)

        context_chunks = []
        if load_faiss_retriever and self.faiss_retriever:
            context_chunks.extend(self.faiss_retriever.retrieve_context(query=prompt, db=db))

        mesh_context = self._gather_mesh_context(db, session, user_service)
        tools = (
            self.tool_service.get_available_tools(
                db, session.user_id, feature=session.feature_name, session_id=session.id
            )
            if self.tool_service else []
        )
        profile = get_profile(session.feature_name)

        # Per-turn accumulators, mutated by _process_event / _update_assistant_db.
        state = {
            "answer": "",
            "reasoning": "",
            "tool_counts": {},
            "usage": {"input": 0, "output": 0},
            "msg": None,  # assistant Message row, created lazily on the first streamed chunk
            "chunks_since_commit": 0,
            "last_commit_at": time.monotonic(),
        }

        architect = Architect()
        async for event in architect.run(
            question=prompt, history=session.messages, context_chunks=context_chunks,
            llm_provider=llm_provider, prompt_service=self.prompt_service, tool_service=self.tool_service,
            tools=tools, mesh_context=mesh_context, db=db, user_id=user_id or session.user_id,
            sync_workspace_id=session.sync_workspace_id or str(session_id), session_id=session_id,
            feature_name=session.feature_name, prompt_slug=profile.default_prompt_slug,
            session_override=session.system_prompt_override
        ):
            await self._process_event(db, session_id, event, state)
            yield event

        # Final persistence: write the complete answer regardless of intermediate commits.
        assistant_msg = await self._finalize_assistant_message(db, session_id, state)
        yield {
            "type": "finish", "message_id": assistant_msg.id, "provider": resolved_provider_name,
            "full_answer": state["answer"], "tool_counts": state["tool_counts"],
            "input_tokens": state["usage"]["input"], "output_tokens": state["usage"]["output"]
        }

    def _load_session(self, db: Session, session_id: int) -> Optional[models.Session]:
        """Return the session with its messages eagerly loaded, or ``None`` if it does not exist."""
        return (
            db.query(models.Session)
            .options(joinedload(models.Session.messages))
            .filter(models.Session.id == session_id)
            .first()
        )

    def _resolve_session(self, db: Session, session_id: int, prompt: str, save_prompt: bool = True) -> models.Session:
        """
        Load the session, optionally record the user's prompt, and commit.

        A session whose title is ``None``, ``""`` or ``"New Chat Session"`` gets a title
        built from the first 60 characters of ``prompt``. ``"..."`` is appended when the
        prompt was truncated.

        Raises:
            ValueError: If the session does not exist.
        """
        session = self._load_session(db, session_id)
        if not session:
            raise ValueError(f"Session {session_id} not found.")

        if save_prompt:
            db.add(models.Message(session_id=session_id, sender="user", content=prompt))
        if session.title in (None, "New Chat Session", ""):
            session.title = prompt[:60].strip() + ("..." if len(prompt) > 60 else "")

        db.commit()
        return session

    def _resolve_provider(self, db: Session, session: models.Session, provider_name: str, user_service) -> Tuple[Any, str]:
        """
        Resolve the LLM provider for this turn.

        When ``services.preference_service`` is available, its ``resolve_llm_provider`` is
        used. Its return value is passed through as-is and is presumably the same
        ``(provider, resolved_name)`` pair — verify against that service. Otherwise the
        provider is built directly with ``get_llm_provider`` and ``provider_name`` is
        reported unchanged.

        ``user_service`` is not used here; it is accepted for call-site compatibility.
        """
        pref_svc = getattr(self.services, "preference_service", None) if self.services else None
        if not pref_svc:
            return get_llm_provider(provider_name, model_name=session.model_name), provider_name

        return pref_svc.resolve_llm_provider(
            db, session.user, provider_name, model_name=session.model_name
        )

    def _gather_mesh_context(self, db: Session, session: models.Session, user_service) -> str:
        """
        Render a plain-text summary of the agent nodes attached to ``session``.

        Returns ``""`` when the session has no attached nodes or its feature profile
        disables mesh context. Otherwise each matching ``AgentNode`` row contributes:

        - its id, name and status;
        - its architecture, if known;
        - its privilege level;
        - its sandbox mode, if configured;
        - a tail of its terminal output, when a node registry is available.
        """
        profile = get_profile(session.feature_name)
        if not session.attached_node_ids or not profile.include_mesh_context:
            return ""

        nodes = db.query(models.AgentNode).filter(models.AgentNode.node_id.in_(session.attached_node_ids)).all()
        # Live node state: prefer the injected registry, fall back to the user service's.
        # Resolved once, since it does not depend on the node.
        registry = self.node_registry_service or getattr(user_service, "node_registry_service", None)

        parts = ["Attached Agent Nodes (Infrastructure):\n"]
        for node in nodes:
            parts.append(f"- Node ID: {node.node_id}\n  Name: {node.display_name}\n")
            parts.append(f"  Status: {node.last_status}\n")

            caps = node.capabilities or {}
            if caps.get("arch"):
                parts.append(f"  Arch: {caps['arch']} ({caps.get('os', 'unknown')})\n")

            # Privilege inference: capability flags are compared against the string "true".
            is_root = caps.get("is_root") == "true"
            has_sudo = caps.get("has_sudo") == "true"
            privilege = "root" if is_root else "sudo-user" if has_sudo else "standard"
            parts.append(f"  Privilege: {privilege}\n")

            # Sandbox status; `or {}` also tolerates explicit None values in the config.
            shell_cfg = (node.skill_config or {}).get("shell") or {}
            sandbox = shell_cfg.get("sandbox") or {}
            if sandbox:
                parts.append(f"  Sandbox: {sandbox.get('mode', 'PERMISSIVE')}\n")

            if registry:
                parts.append(self._render_node_history(registry, node.node_id))

        return "".join(parts)

    def _render_node_history(self, registry, node_id: str) -> str:
        """
        Render the recent terminal output of a live node, with ANSI escapes stripped.

        The last ``_TERMINAL_TAIL_CHUNKS`` history entries are walked newest-first until
        more than ``_TERMINAL_TAIL_SCAN_CHARS`` characters have been collected. The result
        is then cut to its last ``_TERMINAL_TAIL_MAX_CHARS`` characters.

        Returns ``""`` if the registry has no such node or it has no terminal history.
        """
        live = registry.get_node(node_id)
        if not live or not live.terminal_history:
            return ""

        recent = list(live.terminal_history)[-self._TERMINAL_TAIL_CHUNKS:]
        newest_first, total_len = [], 0
        for chunk in reversed(recent):
            text = self._chunk_to_text(chunk)
            newest_first.append(text)
            total_len += len(text)
            if total_len > self._TERMINAL_TAIL_SCAN_CHARS:
                break

        # Restore chronological order before joining.
        clean = ANSI_ESCAPE.sub('', "".join(reversed(newest_first)))
        if len(clean) > self._TERMINAL_TAIL_MAX_CHARS:
            clean = "...[truncated]...\n" + clean[-self._TERMINAL_TAIL_MAX_CHARS:]
        return f"  Recent Terminal Output:\n  ```\n  {clean}\n  ```\n"

    @staticmethod
    def _chunk_to_text(chunk) -> str:
        """
        Coerce one terminal-history entry to text.

        - Strings pass through unchanged.
        - Dicts contribute their ``"output"`` value, or their ``str()`` if that key is
          absent.
        - Anything else is ``str()``-ed.

        A non-string ``"output"`` value is coerced as well, with ``None`` becoming ``""``,
        so it cannot break the length accounting or the join.
        """
        if isinstance(chunk, str):
            return chunk
        if isinstance(chunk, dict):
            output = chunk.get("output", str(chunk))
            if output is None:
                return ""
            return output if isinstance(output, str) else str(output)
        return str(chunk)

    async def _process_event(self, db, session_id, event, state):
        """
        Fold one pipeline event into ``state`` and persist streamed text.

        - ``content`` / ``reasoning``: appended to ``state["answer"]`` /
          ``state["reasoning"]`` and forwarded to ``_update_assistant_db``.
        - ``tool_start``: increments ``calls`` for the named tool.
        - ``tool_result``: increments ``failures`` or ``successes`` for a tool already
          seen in a ``tool_start`` event.
        - ``token_counted``: adds ``usage.prompt_tokens`` / ``usage.completion_tokens`` to
          ``state["usage"]``.

        Other event types do not change ``state``.
        """
        e_type = event["type"]
        if e_type == "content":
            state["answer"] += event["content"]
        elif e_type == "reasoning":
            state["reasoning"] += event["content"]
        elif e_type == "tool_start":
            name = event.get("name")
            if name:
                stats = state["tool_counts"].setdefault(name, {"calls": 0, "successes": 0, "failures": 0})
                stats["calls"] += 1
        elif e_type == "tool_result":
            name, res = event.get("name"), event.get("result")
            if name and name in state["tool_counts"]:
                # A truthy result is a failure unless it is a dict whose "success" is not
                # False. Falsy results (None, {}, "") count as successes.
                # NOTE(review): confirm non-dict results really indicate tool failure.
                failed = bool(res) and (not isinstance(res, dict) or res.get("success") is False)
                state["tool_counts"][name]["failures" if failed else "successes"] += 1
        elif e_type == "token_counted":
            usage = event.get("usage") or {}
            state["usage"]["input"] += usage.get("prompt_tokens") or 0
            state["usage"]["output"] += usage.get("completion_tokens") or 0

        # Persistent UI observability: mirror streamed text into the assistant DB row.
        if e_type in ("content", "reasoning"):
            await self._update_assistant_db(db, session_id, event, state)

    async def _update_assistant_db(self, db, session_id, event, state):
        """
        Incrementally persist the streamed assistant reply for real-time frontend visibility.

        On the first streamed chunk this creates and commits the assistant ``Message`` row.
        Each chunk is appended to its ``content``, or to ``reasoning_content`` when the
        model has that attribute.

        Commits are throttled to at most one per ``_COMMIT_EVERY_N_CHUNKS`` chunks or
        ``_COMMIT_INTERVAL_SECONDS`` seconds. A failed intermediate commit is logged and
        rolled back instead of aborting the stream. ``_finalize_assistant_message`` writes
        the complete text at the end regardless.
        """
        msg = state["msg"]
        if msg is None:
            msg = models.Message(session_id=session_id, sender="assistant", content="")
            state["msg"] = msg
            db.add(msg)
            await async_db_op(db.commit)
            state["chunks_since_commit"] = 0
            state["last_commit_at"] = time.monotonic()

        if event["type"] == "content":
            msg.content += event["content"]
        elif event["type"] == "reasoning" and hasattr(msg, "reasoning_content"):
            msg.reasoning_content = (msg.reasoning_content or "") + event["content"]

        # Pace commits by streamed chunks and wall-clock time. Token usage only changes on
        # token_counted events, so it cannot pace commits while text is streaming.
        state["chunks_since_commit"] = state.get("chunks_since_commit", 0) + 1
        now = time.monotonic()
        due = (
            state["chunks_since_commit"] >= self._COMMIT_EVERY_N_CHUNKS
            or now - state.get("last_commit_at", 0.0) >= self._COMMIT_INTERVAL_SECONDS
        )
        if not due:
            return

        state["chunks_since_commit"] = 0
        state["last_commit_at"] = now
        try:
            await async_db_op(db.commit)
        except Exception:
            # Best effort only. Catching Exception (not a bare except) lets
            # asyncio.CancelledError / GeneratorExit propagate and stop the stream.
            logger.exception(
                "Incremental commit of assistant message failed for session %s; rolling back",
                session_id,
            )
            await async_db_op(db.rollback)

    async def _finalize_assistant_message(self, db, session_id, state) -> models.Message:
        """
        Write the complete answer, and reasoning where supported, to the assistant message
        and commit.

        The message is created here if no content/reasoning chunk was ever streamed, so
        the caller always receives a persisted assistant message.
        """
        msg = state["msg"]
        if msg is None:
            msg = models.Message(session_id=session_id, sender="assistant", content="")
            db.add(msg)
        msg.content = state["answer"]
        if hasattr(msg, "reasoning_content"):
            msg.reasoning_content = state["reasoning"]
        await async_db_op(db.commit)
        return msg

    def get_message_history(self, db: Session, session_id: int) -> Optional[List[models.Message]]:
        """
        Return the session's messages ordered by ``created_at``.

        Returns ``None`` if the session does not exist.
        """
        session = self._load_session(db, session_id)
        if session is None:
            return None
        return sorted(session.messages, key=lambda m: m.created_at)