# cortex-hub / ai-hub / app / core / services / workspace.py
import ast  # Used to safely parse Python-literal list output from the LLM
import json
import logging
import re
import uuid
from datetime import datetime
from typing import Any, Awaitable, Callable, Dict, List, Optional

import dspy
from fastapi import Depends, WebSocket
from sqlalchemy.orm import Session, joinedload

from app.db import models
from app.db import file_retriever_models
from app.db.session import SessionLocal
from app.core.providers.factory import get_llm_provider
from app.core.pipelines.file_selector import CodeRagFileSelector
from app.core.pipelines.dspy_rag import DspyRagPipeline
from app.core.pipelines.question_decider import CodeRagQuestionDecider
from app.core.pipelines.context_compressor import StringContextCompressor
from app.core.retrievers.file_retriever import FileRetriever
# Logger for this module, named after the module itself.
logger = logging.getLogger(__name__)
# Shape every websocket message handler must have:
# an async callable taking (websocket, message_dict) and returning None.
MessageHandler = Callable[[WebSocket, Dict[str, Any]], Awaitable[None]]

class WorkspaceService:
    """
    Manages the full lifecycle of an AI workspace session, including
    handling various message types and dispatching them to the correct handlers.

    Incoming websocket messages are routed by their ``type`` field through
    ``message_handlers``. Commands for the client are described in
    ``command_map`` and sent with ``send_command``.

    A single SQLAlchemy session (``self.db``) is opened per instance and shared
    by every handler. Writes go through ``_commit``, so a failed commit is
    rolled back instead of leaving that shared session unusable for every
    later message. Call ``close()`` when the service is no longer needed.
    """

    def __init__(self):
        # Dispatcher map: message type -> handler coroutine.
        self.message_handlers: Dict[str, MessageHandler] = {
            "list_directory_response": self.handle_list_directory_response,
            "file_content_response": self.handle_files_content_response,
            "execute_command_response": self.handle_command_output,
            "chat_message": self.handle_chat_message,
            # Add more message types here as needed
        }
        # Centralized map of commands that can be sent to the client.
        self.command_map: Dict[str, Dict[str, Any]] = {
            "list_directory": {"type": "list_directory", "description": "Request a list of files and folders in the current directory."},
            "get_file_content": {"type": "get_file_content", "description": "Request the content of a specific file."},
            "execute_command": {"type": "execute_command", "description": "Request to execute a shell command."},
            # Define more commands here
        }
        # Per-websocket session state.
        # NOTE(review): not read by any method in this class yet.
        self.sessions: Dict[str, Dict[str, Any]] = {}
        # Database session shared by every handler of this instance.
        self.db = SessionLocal()
        self.file_retriever = FileRetriever()

    def close(self) -> None:
        """Close the database session opened in ``__init__``."""
        self.db.close()

    # --- Small shared helpers ---

    def _commit(self) -> None:
        """
        Commit the shared session, rolling it back if the commit fails.

        Without the rollback, a failed flush would leave ``self.db`` in an
        unusable state for every later message handled by this instance.
        The original exception is re-raised.
        """
        try:
            self.db.commit()
        except Exception:
            self.db.rollback()
            raise

    @staticmethod
    async def _send_json(websocket: WebSocket, payload: Dict[str, Any]) -> None:
        """Serialize ``payload`` as JSON and send it as a websocket text frame."""
        await websocket.send_text(json.dumps(payload))

    async def _send_error(self, websocket: WebSocket, content: str) -> None:
        """Send an ``{"type": "error"}`` message carrying ``content`` to the client."""
        await self._send_json(websocket, {"type": "error", "content": content})

    async def _parse_request_id(self, websocket: WebSocket, request_id: Any) -> Optional[uuid.UUID]:
        """
        Convert a client-supplied request_id into a UUID.

        Returns:
            The parsed UUID. Returns None if the id is missing or malformed; in
            that case an error message has already been sent to the client.
        """
        if not request_id:
            await self._send_error(websocket, "Error: request_id is missing in the response.")
            return None
        try:
            return uuid.UUID(str(request_id))
        except ValueError:
            await self._send_error(websocket, f"Error: Invalid request_id {request_id}.")
            return None

    # --- Database helpers ---

    async def _update_file_content(self, request_id: uuid.UUID, files_with_content: List[Dict[str, Any]]):
        """
        Store client-supplied file content on the matching RetrievedFile rows.

        This is called after the client answers a ``get_file_content`` command.
        Each entry is read through its ``filepath`` and ``content`` keys.
        An entry is skipped with a warning in either case:
        - it lacks a path or its content is None;
        - its path has no row for ``request_id``.

        Raises:
            Exception: any error is logged, the session is rolled back, and
                the error is re-raised.
        """
        if not files_with_content:
            logger.warning("No files with content provided to update.")
            return

        logger.info(f"Starting content update for {len(files_with_content)} files for request {request_id}")

        try:
            retrieved_files = (
                self.db.query(file_retriever_models.RetrievedFile)
                .filter_by(request_id=request_id)
                .all()
            )
            # Path -> row lookup so each incoming file is matched in O(1).
            file_map = {file.file_path: file for file in retrieved_files}

            updated_count = 0
            for file_info in files_with_content:
                file_path = file_info.get("filepath")
                content = file_info.get("content")

                if not file_path or content is None:
                    logger.warning("Skipping file with missing filename or content.")
                    continue

                db_file = file_map.get(file_path)
                if db_file is None:
                    logger.warning(f"File {file_path} not found in database for request {request_id}, skipping content update.")
                    continue

                db_file.content = content
                updated_count += 1
                logger.debug(f"Updated content for file: {file_path}")

            self.db.commit()
            logger.info(f"Successfully updated content for {updated_count} files.")

        except Exception as e:
            self.db.rollback()
            logger.error(f"Failed to update file content for request {request_id}: {e}")
            raise

    async def _get_or_create_file_request(self, session_id: int, path: str, prompt: str) -> file_retriever_models.FileRetrievalRequest:
        """
        Return the FileRetrievalRequest for (session_id, path), creating it if needed.

        An existing request has its ``question`` overwritten with ``prompt``.
        In both cases the row is committed and refreshed before it is returned.
        """
        file_request = (
            self.db.query(file_retriever_models.FileRetrievalRequest)
            .filter_by(session_id=session_id, directory_path=path)
            .first()
        )

        if file_request is None:
            file_request = file_retriever_models.FileRetrievalRequest(
                session_id=session_id,
                question=prompt,
                directory_path=path,
            )
            self.db.add(file_request)
        else:
            # Reuse the request for this directory, but point it at the latest prompt.
            file_request.question = prompt

        self._commit()
        self.db.refresh(file_request)
        return file_request

    async def _get_file_request_by_id(self, request_id: uuid.UUID) -> file_retriever_models.FileRetrievalRequest:
        """Return the FileRetrievalRequest with id ``request_id``, or None if there is none."""
        return self.db.query(file_retriever_models.FileRetrievalRequest).filter_by(id=request_id).first()

    async def _store_retrieved_files(self, request_id: uuid.UUID, files: List[Dict[str, Any]]):
        """
        Synchronize the stored RetrievedFile rows with the client's file list.

        Each incoming entry is read through its ``path``, ``lastModified``
        (milliseconds since the epoch) and ``name`` keys. The steps are:

        - Rows whose incoming timestamp is newer are updated. Their stored
          content is cleared because it belongs to the older version.
        - Unknown paths are inserted with empty content.
        - Rows whose path is absent from the incoming list are deleted.

        Entries with a missing path, or a missing or invalid timestamp, are
        skipped. Repeated paths are processed only once. All changes are
        committed in a single transaction.
        """
        existing_files = (
            self.db.query(file_retriever_models.RetrievedFile)
            .filter_by(request_id=request_id)
            .all()
        )
        existing_files_map = {file.file_path: file for file in existing_files}

        # Paths seen in the incoming list; anything not in here gets purged below.
        incoming_file_paths = set()

        for file_info in files:
            file_path = file_info.get("path")
            last_modified_timestamp_ms = file_info.get("lastModified")

            if not file_path or last_modified_timestamp_ms is None:
                logger.warning("Skipping file with missing path or timestamp.")
                continue

            if file_path in incoming_file_paths:
                # A repeated path would otherwise insert a second row for the same file.
                logger.warning(f"Skipping duplicate entry for file {file_path}.")
                continue

            try:
                last_modified_datetime = datetime.fromtimestamp(float(last_modified_timestamp_ms) / 1000.0)
            except (TypeError, ValueError, OverflowError, OSError) as e:
                logger.warning(f"Skipping file {file_path} with invalid timestamp {last_modified_timestamp_ms!r}: {e}")
                continue

            incoming_file_paths.add(file_path)

            db_file = existing_files_map.get(file_path)
            if db_file is None:
                # Case: newly introduced file. Content is fetched later, on demand.
                logger.info(f"Adding new file: {file_path}")
                self.db.add(file_retriever_models.RetrievedFile(
                    request_id=request_id,
                    file_path=file_path,
                    file_name=file_info.get("name", ""),
                    content="",
                    type="original",
                    last_updated=last_modified_datetime,
                ))
            elif db_file.last_updated is None or last_modified_datetime > db_file.last_updated:
                # Case: the file changed on the client since we last saw it.
                logger.info(f"Updating file {file_path}. New timestamp: {last_modified_datetime}")
                db_file.last_updated = last_modified_datetime
                # Previously fetched content describes the old version. Clear it so
                # stale code is not served as context; it is re-fetched on demand.
                db_file.content = ""
            # Otherwise the stored record is as new or newer: nothing to do.

        files_to_purge = [file for file in existing_files if file.file_path not in incoming_file_paths]
        if files_to_purge:
            logger.info(f"Purging {len(files_to_purge)} non-existing files.")
            for file in files_to_purge:
                self.db.delete(file)

        # Updates, additions and deletions are committed together (rolled back on failure).
        self._commit()
        logger.info("File synchronization complete.")

    # --- Messaging ---

    async def send_command(self, websocket: WebSocket, command_name: str, data: Optional[Dict[str, Any]] = None):
        """
        Send the command registered under ``command_name`` to the client.

        Args:
            websocket: Connection to send on.
            command_name: Key into ``command_map``.
            data: Extra fields merged into the message after ``type``.
                Because they are merged after it, they can override ``type``.

        Raises:
            ValueError: if ``command_name`` is not in ``command_map``.
        """
        command = self.command_map.get(command_name)
        if command is None:
            raise ValueError(f"Unknown command: {command_name}")

        await self._send_json(websocket, {"type": command["type"], **(data or {})})

    async def dispatch_message(self, websocket: WebSocket, message: Dict[str, Any]):
        """
        Route an incoming message to the handler registered for its 'type'.

        Unknown types are logged, and an error message is sent back to the client.
        Exceptions raised by a handler propagate to the caller.
        """
        message_type = message.get("type")
        handler = self.message_handlers.get(message_type)
        if handler is None:
            logger.warning(f"No handler found for message type: {message_type}")
            await self._send_error(websocket, f"Unknown message type: {message_type}")
            return

        logger.debug(f"Dispatching to handler for message type: {message_type}")
        await handler(websocket, message)

    async def handle_list_directory_response(self, websocket: WebSocket, data: Dict[str, Any]):
        """
        Handle the client's response to a ``list_directory`` command.

        The handler works through these steps:
        1. Validate the payload.
        2. Sync the listed files into the database.
        3. Ask CodeRagFileSelector (through the requested LLM provider) which
           files are relevant to the stored question.
        4. Send ``get_file_content`` for the files it picked.

        If the selector output does not parse as a list, a warning is sent as
        a ``thinking_log`` and the flow stops.
        """
        logger.debug(f"List directory response data: {data}")
        files = data.get("files", [])
        if not files:
            await self._send_error(websocket, "Error: No files found in the directory.")
            return

        request_id = data.get("request_id")
        request_uuid = await self._parse_request_id(websocket, request_id)
        if request_uuid is None:
            return

        file_request = await self._get_file_request_by_id(request_uuid)
        if not file_request:
            await self._send_error(websocket, f"Error: No matching file request found for request_id {request_id}.")
            return

        await self._store_retrieved_files(request_id=request_uuid, files=files)

        llm_provider = get_llm_provider(data.get("provider_name", "gemini"))
        file_selector = CodeRagFileSelector()
        with dspy.context(lm=llm_provider):
            # NOTE(review): the retriever receives the raw request_id from the message,
            # while the DB helpers receive a UUID. Confirm which type it expects.
            raw_answer_text = await file_selector(
                question=file_request.question,
                retrieved_data=self.file_retriever.retrieve_by_request_id(self.db, request_id=request_id),
            )

        try:
            # literal_eval only evaluates Python literals, so it is safe on LLM output.
            selected_files = ast.literal_eval(raw_answer_text)
            if not isinstance(selected_files, list):
                raise ValueError("Parsed result is not a list.")
        except (ValueError, SyntaxError, TypeError) as e:
            logger.error(f"Error parsing LLM output: {e}")
            await self._send_json(websocket, {
                "type": "thinking_log",
                "content": f"Warning: AI's file list could not be parsed. Error: {e}"
            })
            return

        await self._send_json(websocket, {
            "type": "thinking_log",
            "content": f"AI selected files: {selected_files}. Now requesting file content."
        })

        # Next step: ask the client for the content of the AI-selected files.
        await self.send_command(websocket, "get_file_content", data={"filepaths": selected_files, "request_id": request_id})

    async def handle_files_content_response(self, websocket: WebSocket, data: Dict[str, Any]):
        """
        Handle file content sent by the client for a previous ``get_file_content``.

        The content is stored, the request context is re-read, and
        CodeRagQuestionDecider chooses the next step:
        - "files": request more files. The LLM output is parsed as a JSON list.
        - "code_change": send a chat_message that includes ``code_diff``.
        - anything else: send a plain chat_message answer.
        """
        files_data: List[Dict[str, str]] = data.get("files", [])
        request_id = data.get("request_id")

        if not files_data:
            logger.warning(f"No files data received for request_id: {request_id}")
            return

        request_uuid = await self._parse_request_id(websocket, request_id)
        if request_uuid is None:
            return

        logger.info(f"Received content for {len(files_data)} files (request_id: {request_id}).")
        await self._update_file_content(request_id=request_uuid, files_with_content=files_data)

        context_data = self.file_retriever.retrieve_by_request_id(self.db, request_id=request_id)
        if not context_data:
            logger.error(f"Context not found for request_id: {request_id}")
            await self._send_error(websocket, "An internal error occurred. Please try again.")
            return

        # Honour the client's provider choice, as handle_list_directory_response does.
        provider_name = data.get("provider_name", "gemini")
        with dspy.context(lm=get_llm_provider(provider_name)):
            question_decider = CodeRagQuestionDecider()
            raw_answer_text, decision, code_diff = await question_decider(
                question=context_data.get("question", ""),
                history="",
                retrieved_data=context_data
            )

        if decision == "files":
            await self._send_json(websocket, {
                "type": "thinking_log",
                "content": f"AI decided more files are needed: {raw_answer_text}."
            })
            try:
                # The LLM is instructed to provide a JSON list.
                # json.JSONDecodeError is a ValueError subclass.
                file_list = json.loads(raw_answer_text)
                if not isinstance(file_list, list):
                    raise ValueError("Parsed result is not a list.")
            except (ValueError, TypeError) as e:
                logger.error(f"Error parsing LLM output: {e}")
                await self._send_json(websocket, {
                    "type": "thinking_log",
                    "content": f"Warning: AI's file list could not be parsed. Error: {e}"
                })
                return

            await self.send_command(websocket, "get_file_content", data={"filepaths": file_list, "request_id": request_id})

        elif decision == "code_change":
            await self._send_json(websocket, {
                "type": "chat_message",
                "content": raw_answer_text,
                "code_diff": code_diff
            })

        else:  # decision is "answer"
            await self._send_json(websocket, {
                "type": "chat_message",
                "content": raw_answer_text
            })

    async def handle_command_output(self, websocket: WebSocket, data: Dict[str, Any]):
        """Log a client command's output and acknowledge it with a ``thinking_log``."""
        command = data.get("command")
        output = data.get("output")
        request_id = data.get("request_id")

        logger.info(f"Received output for command '{command}' (request_id: {request_id}). Output: {output}")

        # TODO: have the AI analyse the output to determine the next step.
        await self._send_json(websocket, {
            "type": "thinking_log",
            "content": f"Command '{command}' completed. Analyzing output."
        })

    async def handle_chat_message(self, websocket: WebSocket, data: Dict[str, Any]):
        """
        Handle a chat message from the client.

        The user message is persisted first. If the payload includes a ``path``,
        the file-retrieval flow starts by sending ``list_directory``. Otherwise
        DspyRagPipeline answers directly; the reply is persisted and sent back
        as a ``chat_message``.
        """
        # TODO: Enhance this function to process the chat message and determine the next action.
        prompt = data.get("content")
        provider_name = data.get("provider_name", "gemini")
        session_id = data.get("session_id")
        if session_id is None:
            await self._send_error(websocket, "Error: session_id is required for chat messages.")
            return
        if not prompt:
            # Without this guard, a None/empty prompt would be persisted as message content.
            await self._send_error(websocket, "Error: content is required for chat messages.")
            return

        session = (
            self.db.query(models.Session)
            .options(joinedload(models.Session.messages))
            .filter(models.Session.id == session_id)
            .first()
        )
        if not session:
            await self._send_error(websocket, f"Error: Session with ID {session_id} not found.")
            return

        user_message = models.Message(session_id=session_id, sender="user", content=prompt)
        self.db.add(user_message)
        self._commit()
        self.db.refresh(user_message)

        path = data.get("path", "")
        if path:
            # A directory path was supplied: start the file-retrieval flow instead of answering now.
            file_request = await self._get_or_create_file_request(session_id, path, prompt)
            await self.send_command(websocket, "list_directory", data={"request_id": str(file_request.id)})
            return

        llm_provider = get_llm_provider(provider_name)
        chat = DspyRagPipeline()
        with dspy.context(lm=llm_provider):
            # NOTE(review): session.messages was loaded before the user message was
            # committed. Whether it now includes that message depends on the
            # session's expire_on_commit setting; confirm the intended history.
            answer_text = await chat(question=prompt, history=session.messages, context_chunks=[])

        # Persist the assistant's response.
        assistant_message = models.Message(session_id=session_id, sender="assistant", content=answer_text)
        self.db.add(assistant_message)
        self._commit()
        self.db.refresh(assistant_message)

        # The client-side `handleChatMessage` handler processes this message.
        await self._send_json(websocket, {
            "type": "chat_message",
            "content": answer_text
        })
        logger.info(f"Sent chat response to client: {answer_text}")