Newer
Older
cortex-hub / ai-hub / app / core / services / workspace.py
import dspy
import json
import uuid
import os
import re
import logging
from datetime import datetime
import ast  # Import the Abstract Syntax Trees module
from typing import Dict, Any, Callable, Awaitable, List, Optional
from fastapi import WebSocket,Depends
from sqlalchemy.orm import Session,joinedload
from app.db import models
from app.db import file_retriever_models
from app.db.session import SessionLocal
from app.core.providers.factory import get_llm_provider
from app.core.pipelines.file_selector import CodeRagFileSelector
from app.core.pipelines.dspy_rag import DspyRagPipeline
from app.core.pipelines.question_decider import CodeRagQuestionDecider

# Module-level logger used by the service for diagnostics.
logger = logging.getLogger(__name__)
# Signature every entry of WorkspaceService.message_handlers must satisfy:
# an async callable taking (websocket, decoded message dict) and returning nothing.
MessageHandler = Callable[[WebSocket, Dict[str, Any]], Awaitable[None]]

class WorkspaceService:
    """
    Manages the full lifecycle of an AI workspace session, including
    handling various message types and dispatching them to the correct handlers.
    """
    def __init__(self):
        # The dispatcher map: keys are message types, values are handler functions
        self.message_handlers: Dict[str, MessageHandler] = {
            # "select_folder_response": self.handle_select_folder_response,
            "list_directory_response": self.handle_list_directory_response,
            "file_content_response": self.handle_files_content_response,
            "execute_command_response": self.handle_command_output,
            "chat_message": self.handle_chat_message,
            # Add more message types here as needed
        }
        # Centralized map of commands that can be sent to the client
        self.command_map: Dict[str, Dict[str, Any]] = {
            "list_directory": {"type": "list_directory", "description": "Request a list of files and folders in the current directory."},
            "get_file_content": {"type": "get_file_content", "description": "Request the content of a specific file."},
            "execute_command": {"type": "execute_command", "description": "Request to execute a shell command."},
            # Define more commands here
        }
        # Per-websocket session state management
        self.sessions: Dict[str, Dict[str, Any]] = {}
        self.db = SessionLocal()

    # --- New helper function for reuse ---
    async def _update_file_content(self, request_id: uuid.UUID, files_with_content: List[Dict[str, Any]]):
        """
        Updates the content of existing file records in the database.

        This function is called after the client has sent the content for the
        files selected by the AI. It iterates through the provided file data,
        finds the corresponding database record, and updates its 'content' field.
        """
        if not files_with_content:
            logger.warning("No files with content provided to update.")
            return

        logger.info(f"Starting content update for {len(files_with_content)} files for request {request_id}")

        try:
            # Fetch all files for the given request ID to build a quick lookup map
            retrieved_files = self.db.query(file_retriever_models.RetrievedFile).filter_by(
                request_id=request_id
            ).all()
            file_map = {file.file_path: file for file in retrieved_files}
            
            updated_count = 0
            for file_info in files_with_content:
                file_path = file_info.get("filepath")
                content = file_info.get("content")

                if not file_path or content is None:
                    logger.warning("Skipping file with missing filename or content.")
                    continue

                # Find the corresponding record in our map
                if file_path in file_map:
                    db_file = file_map[file_path]
                    db_file.content = content
                    updated_count += 1
                    logger.debug(f"Updated content for file: {file_path}")
                else:
                    logger.warning(f"File {file_path} not found in database for request {request_id}, skipping content update.")
            
            # Commit the changes to the database
            self.db.commit()
            logger.info(f"Successfully updated content for {updated_count} files.")
        
        except Exception as e:
            self.db.rollback()
            logger.error(f"Failed to update file content for request {request_id}: {e}")
            raise

    async def _get_or_create_file_request(self, session_id: int, path: str, prompt: str) -> file_retriever_models.FileRetrievalRequest:
        """
        Retrieves an existing FileRetrievalRequest or creates a new one if it doesn't exist.
        """
        file_request = self.db.query(file_retriever_models.FileRetrievalRequest).filter_by(
            session_id=session_id, directory_path=path
        ).first()

        if not file_request:
            file_request = file_retriever_models.FileRetrievalRequest(
                session_id=session_id,
                question=prompt,
                directory_path=path
            )
            self.db.add(file_request)
            self.db.commit()
            self.db.refresh(file_request)
        else:
            # If file_request is found, update it with the latest prompt
            file_request.question = prompt
            self.db.commit()
            self.db.refresh(file_request)
        return file_request
    
    async def _get_file_request_by_id(self, request_id: uuid.UUID) -> file_retriever_models.FileRetrievalRequest:
        """
        Retrieves a FileRetrievalRequest by its ID.
        """
        return self.db.query(file_retriever_models.FileRetrievalRequest).filter_by(id=request_id).first()

    async def _store_retrieved_files(self, request_id: uuid.UUID, files: List[Dict[str, Any]]):
        """
        Synchronizes the database's retrieved files with the client's file list.

        This function compares existing files against new ones and performs
        updates, additions, or deletions as necessary.
        """
        # 1. Get existing files from the database for this request
        existing_files = self.db.query(file_retriever_models.RetrievedFile).filter_by(request_id=request_id).all()
        existing_files_map = {file.file_path: file for file in existing_files}
        
        # Keep track of which existing files are also in the incoming list
        incoming_file_paths = set()
        
        # 2. Iterate through incoming files to handle updates and additions
        for file_info in files:
            file_path = file_info.get("path")
            last_modified_timestamp_ms = file_info.get("lastModified")

            if not file_path or last_modified_timestamp_ms is None:
                logger.warning("Skipping file with missing path or timestamp.")
                continue
                
            last_modified_datetime = datetime.fromtimestamp(last_modified_timestamp_ms / 1000.0)
            incoming_file_paths.add(file_path)

            # Check if the file already exists in the database
            if file_path in existing_files_map:
                db_file = existing_files_map[file_path]
                # Compare the last modified timestamps
                if last_modified_datetime > db_file.last_updated:
                    # Case: File has been updated, so override the existing record.
                    logger.info(f"Updating file {file_path}. New timestamp: {last_modified_datetime}")
                    db_file.last_updated = last_modified_datetime
                    # The content remains empty for now, as it will be fetched later.
                else:
                    # Case: File is identical or older, do nothing.
                    # logger.debug(f"File {file_path} is identical or older, skipping.")
                    pass
                    
            else:
                # Case: This is a newly introduced file.
                logger.info(f"Adding new file: {file_path}")
                new_file = file_retriever_models.RetrievedFile(
                    request_id=request_id,
                    file_path=file_path,
                    file_name=file_info.get("name", ""),
                    content="",  # Content is deliberately left empty.
                    type="original",
                    last_updated=last_modified_datetime,
                )
                self.db.add(new_file)

        # 3. Purge non-existing files
        # Find files in the database that were not in the incoming list
        files_to_purge = [
            file for file in existing_files if file.file_path not in incoming_file_paths
        ]
        if files_to_purge:
            logger.info(f"Purging {len(files_to_purge)} non-existing files.")
            for file in files_to_purge:
                self.db.delete(file)

        # 4. Commit all changes (updates, additions, and deletions) in a single transaction
        self.db.commit()
        logger.info("File synchronization complete.")

    # def generate_request_id(self) -> str:
    #     """Generates a unique request ID."""
    #     return str(uuid.uuid4())

    async def _retrieve_by_request_id(self, db: Session, request_id: str) -> Optional[Dict[str, Any]]:
        """
        Retrieves a FileRetrievalRequest and all its associated files from the database,
        returning the data in a well-formatted JSON-like dictionary.

        Args:
            db: The SQLAlchemy database session.
            request_id: The UUID of the FileRetrievalRequest.

        Returns:
            A dictionary containing the request and file data, or None if the request is not found.
        """
        try:
            # Convert string request_id to UUID object for the query
            request_uuid = uuid.UUID(request_id)
        except ValueError:
            print(f"Invalid UUID format for request_id: {request_id}")
            return None

        # Fetch the request and its related files in a single query using join
        request = db.query(file_retriever_models.FileRetrievalRequest).filter(
            file_retriever_models.FileRetrievalRequest.id == request_uuid
        ).options(
            # Eagerly load the retrieved_files to avoid N+1 query problem
            joinedload(file_retriever_models.FileRetrievalRequest.retrieved_files)
        ).first()

        if not request:
            return None

        # Build the dictionary to represent the JSON structure
        retrieved_data = {
            "request_id": str(request.id),
            "question": request.question,
            "directory_path": request.directory_path,
            "session_id": request.session_id,
            "created_at": request.created_at.isoformat() if request.created_at else None,
            "retrieved_files": []
        }

        for file in request.retrieved_files:
            if file.content:
                # For files with content, show the full detailed structure
                file_data = {
                    "file_path": file.file_path,
                    "content": file.content,
                    "id": str(file.id),
                    "name": file.file_name,
                    "type": file.type,
                    "last_updated": file.last_updated.isoformat() if file.last_updated else None,
                    "created_at": file.created_at.isoformat() if file.created_at else None,
                }
            else:
                # For empty files, use a compact representation
                file_data = {
                    "file_path": file.file_path,
                    "type": file.type
                }
            retrieved_data["retrieved_files"].append(file_data)

        return retrieved_data

    async def get_file_content_by_request_id_and_path(self, db: Session, request_id: uuid.UUID, file_path: str) ->str:
        """
        Retrieves a FileRetrievalRequest by its ID.
        """
        retrievedFile =  db.query(file_retriever_models.RetrievedFile).filter_by(request_id = request_id , file_path=file_path).first()
        if retrievedFile and retrievedFile.content:
            return retrievedFile.content
        else:
            logger.warning(f"File with path {file_path} not found for request ID {request_id} or has no content.")
            return ""
    
    async def _handle_code_change_response(self, db: Session, request_id: str, code_diff: str) -> List[Dict[str, Any]]: 
        """
        Parses the diff, retrieves original file content, and returns a structured,
        per-file dictionary for the client.
        """
        # Normalize the diff string to ensure consistent splitting, handling cases where 
        # the separator may be missing a leading newline.
        normalized_diff = re.sub(r'(?<!^)(?<!\n)--- a\/', '\n--- a/', code_diff)

        # 1. Split the monolithic code_diff string into per-file diffs.
        # The updated regex is more robust and handles both standard "diff --git" headers and 
        # non-standard file separators like "--- a/".
        per_file_diffs = re.split(r'(?=\ndiff --git a\/|\n--- a\/)', normalized_diff)
        
        # 2. Iterate through each per-file diff to get file path and retrieve content.
        files_with_diff_and_content = []
        
        for file_diff in per_file_diffs:
            if not file_diff.strip():
                continue

            # Use a regex to find the file path from the "--- a/path" line
            path_match = re.search(r'--- a(.*)', file_diff)
            if path_match:
                file_path = path_match.group(1).strip()
                
                # Retrieve the original content for this specific file.
                original_content = await self.get_file_content_by_request_id_and_path(
                    db, 
                    uuid.UUID(request_id), 
                    file_path
                )
                
                # Group the file path, its diff, and its original content.
                files_with_diff_and_content.append({
                    "filepath": file_path,
                    "diff": file_diff,
                    "original_content": original_content,
                    "new_content": self._apply_diff(original_content, file_diff)
                })
        return files_with_diff_and_content
        
    
    async def get_files_by_request_id(self, db: Session, request_id: str) -> Optional[List[str]]:
        """
        Retrieves all files associated with a FileRetrievalRequest from the database,
        returning the data in a list of JSON-like dictionaries.

        Args:
            db: The SQLAlchemy database session.
            request_id: The UUID of the FileRetrievalRequest.

        Returns:
            A list of dictionaries containing file data, or None if the request is not found.
        """
        try:
            request_uuid = uuid.UUID(request_id)
        except ValueError:
            print(f"Invalid UUID format for request_id: {request_id}")
            return None

        request = db.query(file_retriever_models.FileRetrievalRequest).filter(
            file_retriever_models.FileRetrievalRequest.id == request_uuid
        ).options(
            joinedload(file_retriever_models.FileRetrievalRequest.retrieved_files)
        ).first()

        if not request:
            return None

        retrieved_files = []

        for file in request.retrieved_files:
            retrieved_files.append(file.file_path)

        return retrieved_files
    
    def _format_diff(self, raw_diff: str) -> str:
        # Remove Markdown-style code block markers
        content = re.sub(r'^```diff\n|```$', '', raw_diff.strip(), flags=re.MULTILINE)
        
        # Unescape common sequences
        content = content.encode('utf-8').decode('unicode_escape')

        return content

    def _apply_diff(self, original_content: str, file_diff: str) -> str:
        """
        Applies a unified diff to the original content and returns the new content.

        Args:
            original_content: The original file content as a single string.
            file_diff: The unified diff string.

        Returns:
            The new content with the diff applied.
        """
        # Handle the case where the original content is empty.
        if not original_content:
            new_content: List[str] = []
            for line in file_diff.splitlines(keepends=True):
                if line.startswith('+') and not line.startswith('+++'):
                    new_content.append(line[1:])
            return ''.join(new_content)

        original_lines = original_content.splitlines(keepends=True)
        diff_lines = file_diff.splitlines(keepends=True)

        i = 0
        new_content: List[str] = []
        orig_idx = 0

        while i < len(diff_lines):
            # Skip diff headers like --- and +++
            if diff_lines[i].startswith('---') or diff_lines[i].startswith('+++'):
                i += 1
                continue

            # Hunk header
            if not diff_lines[i].startswith('@@'):
                i += 1
                continue

            hunk_header = diff_lines[i]
            m = re.match(r'^@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@', hunk_header)
            if not m:
                raise ValueError(f"Invalid hunk header: {hunk_header.strip()}")

            orig_start = int(m.group(1)) - 1  # convert from 1-based to 0-based index
            i += 1

            # Copy unchanged lines before this hunk
            while orig_idx < orig_start:
                new_content.append(original_lines[orig_idx])
                orig_idx += 1

            # Process lines in hunk
            while i < len(diff_lines):
                line = diff_lines[i]

                if line.startswith('@@'):
                    # Start of next hunk
                    break
                elif line.startswith(' '):
                    # Context line
                    if orig_idx < len(original_lines):
                        new_content.append(original_lines[orig_idx])
                        orig_idx += 1
                elif line.startswith('-'):
                    # Removed line
                    orig_idx += 1
                elif line.startswith('+'):
                    # Added line
                    new_content.append(line[1:])
                i += 1

        # Add remaining lines from original
        new_content.extend(original_lines[orig_idx:])

        return ''.join(new_content)

    
    async def send_command(self, websocket: WebSocket, command_name: str, data: Dict[str, Any] = {}):
        if command_name not in self.command_map:
            raise ValueError(f"Unknown command: {command_name}")
        
        message_to_send = {
            "type": self.command_map[command_name]["type"],
            **data,
        }
        
        await websocket.send_text(json.dumps(message_to_send))
        
    async def dispatch_message(self, websocket: WebSocket, message: Dict[str, Any]):
        """
        Routes an incoming message to the appropriate handler based on its 'type'.
        Retrieves session state to maintain context.
        """
        message_type = message.get("type")
        # request_id = message.get("request_id")
        # round_num = message.get("round")
        
        # In a real-world app, you'd retrieve historical data based on request_id or session_id
        # For this example, we'll just print it.
        # print(f"Received message of type '{message_type}' (request_id: {request_id}, round: {round_num})")
        # logger.info(f"Received message: {message}")
        
        handler = self.message_handlers.get(message_type)
        if handler:
            logger.debug(f"Dispatching to handler for message type: {message_type}")
            await handler(websocket, message)
        else:
            print(f"Warning: No handler found for message type: {message_type}")
            await websocket.send_text(json.dumps({"type": "error", "content": f"Unknown message type: {message_type}"}))

    # async def handle_select_folder_response(self, websocket:WebSocket, data: Dict[str, Any]):
    #     """Handles the client's response to a select folder response."""
    #     path = data.get("path")
    #     request_id = data.get("request_id")
    #     if request_id is None:
    #         await websocket.send_text(json.dumps({
    #             "type": "error",
    #             "content": "Error: request_id is missing in the response."
    #         }))
    #         return
    #     print(f"Received folder selected (request_id: {request_id}): Path: {path}")
        
        # After a folder is selected, the next step is to list its contents.
        # This now uses the send_command helper.
        # await self.send_command(websocket, "list_directory", data={"path": path, "request_id": request_id})

    async def handle_list_directory_response(self, websocket: WebSocket, data: Dict[str, Any]):
        """Handles the client's response to a list_directory request."""
        logger.debug(f"List directory response data: {data}")
        files = data.get("files", [])
        if not files:
            await websocket.send_text(json.dumps({
                "type": "error",
                "content": "Error: No files found in the directory."
            }))
            return
        request_id = data.get("request_id")
        if not request_id:
            await websocket.send_text(json.dumps({
                "type": "error",
                "content": "Error: request_id is missing in the response."
            }))
            return
        file_request = await self._get_file_request_by_id(uuid.UUID(request_id))
        if not file_request:
            await websocket.send_text(json.dumps({
                "type": "error",
                "content": f"Error: No matching file request found for request_id {request_id}."
            }))
            return
        await self._store_retrieved_files(request_id=uuid.UUID(request_id), files=files)
        
        provider_name = data.get("provider_name", "gemini")
        llm_provider = get_llm_provider(provider_name)
        cfs = CodeRagFileSelector()
        
        with dspy.context(lm=llm_provider):
            raw_answer_text  = await cfs(
                question=file_request.question,
                retrieved_data = await self.get_files_by_request_id(self.db, request_id=request_id)
            )
        try:
            # Use ast.literal_eval for a safe and reliable parse
            answer_text = ast.literal_eval(raw_answer_text)
        except (ValueError, SyntaxError) as e:
    # Handle cases where the LLM output is not a valid list string.
            print(f"Error parsing LLM output: {e}")
            answer_text = []  # Default to an empty list to prevent errors.
            await websocket.send_text(json.dumps({
                "type": "thinking_log",
                "content": f"Warning: AI's file list could not be parsed. Error: {e}"
            }))
            return
        
        if len(answer_text) == 0:
            await websocket.send_text(json.dumps({
                "type": "thinking_log",
                "content": "AI did not select any files to retrieve content for."
            }))
            await self.handle_files_content_response(websocket, {"files": [], "request_id": request_id})
            return
        
        await websocket.send_text(json.dumps({
            "type": "thinking_log",
            "content": f"AI selected files: {answer_text}. Now requesting file content."
        }))
            
        # After getting the AI's selected files, we send a command to the client to get their content.
        await self.send_command(websocket, "get_file_content", data={"filepaths": answer_text, "request_id": request_id})

    async def handle_files_content_response(self, websocket: WebSocket, data: Dict[str, Any]):
        """
        Handles the content of a list of files sent by the client.
        """
        files_data: List[Dict[str, str]] = data.get("files", [])
        request_id = data.get("request_id")
        session_id = data.get("session_id")
        
        if not files_data:
            print(f"Warning: No files data received for request_id: {request_id}")
        else:          
            print(f"Received content for {len(files_data)} files (request_id: {request_id}).")
            await self._update_file_content(request_id=uuid.UUID(request_id), files_with_content=files_data)
        
        if not session_id:
            await websocket.send_text(json.dumps({
                "type": "error",
                "content": "Error: session_id is required to process file content."
            }))
            return
        
        if not request_id:
            await websocket.send_text(json.dumps({
                "type": "error",
                "content": "Error: request_id is required to process file content."
            }))
            return 
        
        # Retrieve the updated context from the database
        context_data = await self._retrieve_by_request_id(self.db, request_id=request_id)
        
        if not context_data:
            print(f"Error: Context not found for request_id: {request_id}")
            await websocket.send_text(json.dumps({
                "type": "error",
                "content": "An internal error occurred. Please try again."
            }))
            return
        await websocket.send_text(json.dumps({
            "type": "thinking_log",
            "content": f"AI is analyzing the retrieved files to determine next steps."
        }))

        session = self.db.query(models.Session).options(
            joinedload(models.Session.messages)
        ).filter(models.Session.id == session_id).first()

        # Use the LLM to make a decision
        with dspy.context(lm=get_llm_provider(provider_name="gemini")):
            crqd = CodeRagQuestionDecider()
            raw_answer_text, reasoning, decision, code_diff = await crqd(
                question=context_data.get("question", ""),
                history=session.messages,
                retrieved_data=context_data
            )
        dspy.inspect_history(n=1)  # Inspect the last DSPy operation for debugging
        if decision in [ "code_change", "answer"]:
            assistant_message = models.Message(session_id=session_id, sender="assistant", content=raw_answer_text)
            self.db.add(assistant_message)
            self.db.commit()
            self.db.refresh(assistant_message)

        if decision == "files": 
            await websocket.send_text(json.dumps({
                "type": "thinking_log",
                "content": f"AI decided more files are needed: {raw_answer_text}."
            }))
            try:
                # Use regex to find the JSON content, including any surrounding newlines and code blocks
                json_match = re.search(r'\[.*\]', raw_answer_text, re.DOTALL)
                if json_match:
                    # Extract the matched JSON string
                    json_string = json_match.group(0)
                    
                    # Use ast.literal_eval for a safe and reliable parse
                    answer_text = ast.literal_eval(json_string)
                    
                    if not isinstance(answer_text, list):
                        raise ValueError("Parsed result is not a list.")
                else:
                    # Fallback if no markdown is found
                    answer_text = ast.literal_eval(raw_answer_text)
                    if not isinstance(answer_text, list):
                        raise ValueError("Parsed result is not a list.")
            except (ValueError, SyntaxError) as e:
                print(f"Error parsing LLM output: {e}")
                answer_text = []
                await websocket.send_text(json.dumps({
                    "type": "thinking_log",
                    "content": f"Warning: AI's file list could not be parsed. Error: {e}"
                }))
                return
            
            await self.send_command(websocket, "get_file_content", data={"filepaths": answer_text, "request_id": request_id})
        
        elif decision == "code_change":
            diffs =await self._handle_code_change_response(db=self.db, request_id=request_id, code_diff=code_diff)
            for diff in diffs:
                diff["diff"] = self._format_diff(diff.get("diff",""))
            payload = json.dumps({
                "type": "chat_message",
                "content": raw_answer_text,
                "reasoning": reasoning,
                "dicision" : decision,
                "code_diff":diffs
            })
            logger.info(f"Sending code change response to client: {payload}")
            await websocket.send_text(payload)
            
        else: # decision is "answer"
            await websocket.send_text(json.dumps({
                "type": "thinking_log",
                "content": f"Answering user's question directly."
            }))
            await websocket.send_text(json.dumps({
                "type": "chat_message",
                "content": raw_answer_text
            }))
        
    async def handle_command_output(self, websocket: WebSocket, data: Dict[str, Any]):
        """Handles the output from a command executed by the client."""
        command = data.get("command")
        output = data.get("output")
        request_id = data.get("request_id")
        
        print(f"Received output for command '{command}' (request_id: {request_id}). Output: {output}")
        
        # The AI would process the command output to determine the next step
        await websocket.send_text(json.dumps({
            "type": "thinking_log",
            "content": f"Command '{command}' completed. Analyzing output."
        }))
    
    async def handle_chat_message(self, websocket: WebSocket, data: Dict[str, Any]):
        """Handles incoming chat messages from the client."""
        # TODO: Enhance this function to process the chat message and determine the next action.
        prompt = data.get("content")
        provider_name = data.get("provider_name", "gemini")
        session_id = data.get("session_id")
        if session_id is None:
            await websocket.send_text(json.dumps({
                "type": "error",
                "content": "Error: session_id is required for chat messages."
            }))
            return
        session = self.db.query(models.Session).options(
            joinedload(models.Session.messages)
        ).filter(models.Session.id == session_id).first()
        
        if not session:
            await websocket.send_text(json.dumps({
                "type": "error",
                "content": f"Error: Session with ID {session_id} not found."
            }))
            return
        user_message = models.Message(session_id=session_id, sender="user", content=prompt)
        self.db.add(user_message)
        self.db.commit()
        self.db.refresh(user_message)

        path = data.get("path", "")
        if path:
            # If file path is provided, initiate file retrieval process.
            file_request = await self._get_or_create_file_request(session_id, path, prompt)
            await self.send_command(websocket, "list_directory", data={"request_id": str(file_request.id)})
            return
        llm_provider = get_llm_provider(provider_name)
        chat = DspyRagPipeline()
        with dspy.context(lm=llm_provider):
            answer_text = await chat(question=prompt, history=session.messages, context_chunks=[])
                # Save assistant's response
        assistant_message = models.Message(session_id=session_id, sender="assistant", content=answer_text)
        self.db.add(assistant_message)
        self.db.commit()
        self.db.refresh(assistant_message)
        
        # 📝 Add this section to send the response back to the client
        # The client-side `handleChatMessage` handler will process this message
        await websocket.send_text(json.dumps({
            "type": "chat_message",
            "content": answer_text
        }))
        logger.info(f"Sent chat response to client: {answer_text}")