import itertools
import logging
import queue
import threading
import time
from typing import Any, Callable, Optional

from . import agent_pb2, agent_pb2_grpc
from .transport import IMeshListener, IMeshTransport

# One logger per module, named after the module path so logging config can target it.
logger = logging.getLogger(name=__name__)

class GrpcMeshTransport(IMeshTransport):
    """
    gRPC implementation of the Mesh Transport.

    Encapsulates the bidirectional ``TaskStream`` and auto-reconnection logic.
    Outbound messages are buffered in an unbounded priority queue (lower
    ``priority`` values are sent first, FIFO among equal priorities) and are
    drained by the request generator of whichever stream is currently open.
    Inbound messages are forwarded to the registered listener.
    """

    # Seconds the request generator blocks on the send queue before
    # re-checking the stop flags and the keep-alive timer.
    _QUEUE_POLL_INTERVAL = 1.0
    # Seconds between transport-level keep-alive messages.
    _KEEPALIVE_INTERVAL = 10.0
    # Upper bound (seconds) on the reconnect backoff.
    _MAX_BACKOFF = 30

    def __init__(self, node_id: str, stub_factory: Callable[[], tuple], auth_token: str = ""):
        """
        Args:
            node_id: Identifier announced to the server on every stream and
                sent in the handshake.
            stub_factory: Zero-argument callable returning a ``(stub, channel)``
                pair; called again on every (re)connect to build a fresh channel.
            auth_token: Token sent in the handshake ``RegistrationRequest``.
        """
        self.node_id = node_id
        self.stub_factory = stub_factory  # Callable returning (stub, channel)
        self.auth_token = auth_token
        self.listener: Optional[IMeshListener] = None
        self.stub = None
        self.channel = None
        # Entries are (priority, sequence, message). The monotonically
        # increasing sequence breaks ties between equal priorities in FIFO
        # order and guarantees tuple comparison never falls through to the
        # message objects, which are not orderable (protobuf messages raise
        # TypeError on '<').
        self.send_queue = queue.PriorityQueue()
        self._seq = itertools.count()
        self._stop_event = threading.Event()
        self._connected = False
        # time.time() of the last stream open or inbound message.
        self.last_activity = 0

    def handshake(self) -> bool:
        """
        Register this node with the server via ``SyncConfiguration``.

        A fresh stub/channel is built first. Returns True if the server
        accepted the registration; False if it rejected it or if anything
        (including building the stub) raised.
        """
        try:
            self._refresh_stub()
            req = agent_pb2.RegistrationRequest(
                node_id=self.node_id,
                auth_token=self.auth_token,
                node_description="Portable Mesh Node"
            )
            res = self.stub.SyncConfiguration(req)
        except Exception as e:
            logger.error(f"[Mesh] Handshake FAILED for {self.node_id}: {e}")
            return False

        if res.success:
            logger.info(f"[Mesh] Handshake successful for {self.node_id}")
            # Optional: Handle policy res.policy
            return True
        logger.error(f"[Mesh] Handshake REJECTED for {self.node_id}: {res.error_message}")
        return False

    def connect(self) -> None:
        """Build a fresh stub and start the background (daemon) stream thread."""
        self._stop_event.clear()
        self._refresh_stub()
        threading.Thread(target=self._run_stream, daemon=True, name="GrpcTransportStream").start()

    def set_listener(self, listener: IMeshListener) -> None:
        """Register the object that receives on_message/on_error/on_close callbacks."""
        self.listener = listener

    def _refresh_stub(self) -> None:
        """Close the current channel (best effort) and build a new (stub, channel) pair."""
        if self.channel:
            try:
                self.channel.close()
            except Exception as e:
                # Best effort: a channel that fails to close must not prevent
                # building its replacement.
                logger.debug(f"[Mesh] Ignoring error while closing old channel: {e}")
        self.stub, self.channel = self.stub_factory()

    def _outbound_messages(self, call_done: threading.Event):
        """
        Request iterator for a single ``TaskStream`` call.

        Yields the node announcement first, then queued messages as they
        arrive. It also yields a keep-alive whenever ``_KEEPALIVE_INTERVAL``
        seconds have passed since the previous one. Finishes once the
        transport is closed or *call_done* is set (the call it feeds ended).
        """
        # Initial announcement
        yield agent_pb2.ClientTaskMessage(announce=agent_pb2.NodeAnnounce(node_id=self.node_id))

        last_heartbeat = time.time()
        while not self._stop_event.is_set() and not call_done.is_set():
            try:
                # PriorityQueue returns (priority, seq, msg)
                item = self.send_queue.get(timeout=self._QUEUE_POLL_INTERVAL)
            except queue.Empty:
                item = None

            if item is not None:
                if call_done.is_set():
                    # The call ended while we were blocked in get(); a finished
                    # call will not transmit anything yielded now. Hand the
                    # entry (same priority and sequence) to the next stream.
                    self.send_queue.put(item)
                    return
                yield item[2]

            # Transport-level KeepAlive
            if time.time() - last_heartbeat >= self._KEEPALIVE_INTERVAL:
                yield agent_pb2.ClientTaskMessage(
                    skill_event=agent_pb2.SkillEvent(keep_alive=True)
                )
                last_heartbeat = time.time()

    def _run_stream(self) -> None:
        """
        Background loop: keep a ``TaskStream`` open until ``close()`` is called.

        On an error, the listener's ``on_error`` is notified and the stub is
        rebuilt. The loop then waits ``min(_MAX_BACKOFF, 2 * attempt)`` seconds
        before retrying. The attempt counter resets only after the server has
        delivered a message on the stream. A stream the server ends cleanly is
        reopened immediately. Errors caused by ``close()`` itself are not
        reported.
        """
        retry_count = 0
        while not self._stop_event.is_set():
            # Signals this call's request generator to stop pulling from the
            # shared send queue once the call is over.
            call_done = threading.Event()
            try:
                logger.info(f"[*] Opening gRPC TaskStream for {self.node_id}...")
                responses = self.stub.TaskStream(self._outbound_messages(call_done))
                self._connected = True
                self.last_activity = time.time()
                logger.info(f"[✅] gRPC Mesh Transport Online for {self.node_id}")

                for msg in responses:
                    self.last_activity = time.time()
                    # TaskStream() returns immediately and connection errors
                    # only surface while iterating, so a received message is
                    # the first real proof the stream works: reset backoff here.
                    retry_count = 0
                    if self.listener:
                        self.listener.on_message(msg)

                self._connected = False
                logger.warning(f"[📶] gRPC Stream closed by server for {self.node_id}")
            except Exception as e:
                self._connected = False
                if self._stop_event.is_set():
                    # close() tore the channel down; this is the expected
                    # cancellation, not an error to report or retry.
                    break
                if self.listener:
                    self.listener.on_error(e)

                retry_count += 1
                backoff = min(self._MAX_BACKOFF, 2 * retry_count)
                logger.error(f"[❌] gRPC Stream Error ({type(e).__name__}): {e}. Reconnecting in {backoff}s... (Attempt {retry_count})")
                try:
                    self._refresh_stub()
                except Exception as refresh_err:
                    # Keep the loop alive; the next attempt fails fast and
                    # goes through this backoff path again.
                    logger.error(f"[❌] Failed to rebuild gRPC stub for {self.node_id}: {refresh_err}")
                # Interruptible sleep: close() wakes this immediately.
                self._stop_event.wait(backoff)
            finally:
                call_done.set()

    def send(self, message: Any, priority: int = 1) -> None:
        """
        Queue *message* for the current (or next) stream.

        Lower *priority* values are sent first; equal priorities are sent in
        the order they were queued. Never blocks (the queue is unbounded).
        """
        self.send_queue.put((priority, next(self._seq), message))

    def close(self) -> None:
        """
        Stop the stream loop, close the channel, and notify the listener.

        Safe to call before a listener is set. Errors from closing the channel
        are logged, not raised, so ``on_close`` is still delivered.
        """
        self._stop_event.set()
        if self.channel:
            try:
                self.channel.close()
            except Exception as e:
                logger.warning(f"[Mesh] Error while closing channel for {self.node_id}: {e}")
        self._connected = False
        if self.listener:
            self.listener.on_close()

    def is_connected(self) -> bool:
        """
        Return True while a stream is open.

        Set as soon as ``TaskStream()`` returns, before the server has
        responded. Cleared when the stream ends, fails, or ``close()`` runs.
        """
        return self._connected
