# cortex-hub / ai-hub / app / core / services / workspace.py
import dspy
import json
import uuid
import os
import re
import logging
from datetime import datetime
import ast  # Import the Abstract Syntax Trees module
from typing import Dict, Any, Callable, Awaitable, List, Optional
from fastapi import WebSocket,Depends
from sqlalchemy.orm import Session,joinedload
from app.db import models
from app.db import file_retriever_models
from app.db.session import SessionLocal
from app.core.providers.factory import get_llm_provider
# from app.core.pipelines.file_selector import CodeRagFileSelector
from app.core.pipelines.dspy_rag import DspyRagPipeline
from app.core.pipelines.question_decider import CodeRagQuestionDecider
from app.core.services.utils.code_change import CodeChangeHelper
from app.core.pipelines.validator import Validator,TokenLimitExceededError
# Signature shared by every entry in WorkspaceService.message_handlers: an async
# callable that receives the client websocket and the decoded message payload.
MessageHandler = Callable[[WebSocket, Dict[str, Any]], Awaitable[None]]

# Logger for this module; no handlers or levels are configured here.
logger: logging.Logger = logging.getLogger(__name__)

class WorkspaceService:
    """
    Manages the full lifecycle of an AI workspace session over a websocket.

    Incoming messages are routed by their ``type`` field to the coroutines in
    ``message_handlers``; commands sent to the client must be declared in
    ``command_map``. File-retrieval requests and the files attached to them are
    persisted through the SQLAlchemy session created in ``__init__``.
    """

    def __init__(self):
        # Dispatcher map: incoming message type -> handler coroutine.
        self.message_handlers: Dict[str, MessageHandler] = {
            "list_directory_response": self.handle_list_directory_response,
            "file_content_response": self.handle_files_content_response,
            "execute_command_response": self.handle_command_output,
            "chat_message": self.handle_chat_message,
        }
        # Commands the server may send to the client (validated by send_command).
        self.command_map: Dict[str, Dict[str, Any]] = {
            "list_directory": {"type": "list_directory", "description": "Request a list of files and folders in the current directory."},
            "get_file_content": {"type": "get_file_content", "description": "Request the content of a specific file."},
            "execute_command": {"type": "execute_command", "description": "Request to execute a shell command."},
        }
        # Per-websocket session state; not read by any handler in this class yet.
        self.sessions: Dict[str, Dict[str, Any]] = {}
        # NOTE(review): a single DB session lives as long as this service and is
        # never closed here -- confirm the owner of this object disposes of it.
        self.db = SessionLocal()

    # ------------------------------------------------------------------
    # Shared private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _to_uuid(value: Any) -> Optional[uuid.UUID]:
        """Return *value* as a ``uuid.UUID``, or None when it is not a valid UUID."""
        if isinstance(value, uuid.UUID):
            return value
        try:
            return uuid.UUID(str(value))
        except ValueError:
            return None

    @staticmethod
    async def _send_json(websocket: WebSocket, payload: Dict[str, Any]) -> None:
        """Serialize *payload* to JSON and send it to the client as a text frame."""
        await websocket.send_text(json.dumps(payload))

    async def _send_error(self, websocket: WebSocket, content: str) -> None:
        """Send an ``error`` message carrying *content* to the client."""
        await self._send_json(websocket, {"type": "error", "content": content})

    async def _send_thinking_log(self, websocket: WebSocket, content: str) -> None:
        """Send a ``thinking_log`` progress message carrying *content* to the client."""
        await self._send_json(websocket, {"type": "thinking_log", "content": content})

    def _add_message(self, session_id: Any, sender: str, content: Any):
        """Persist a chat message for *session_id* and return the refreshed row."""
        message = models.Message(session_id=session_id, sender=sender, content=content)
        self.db.add(message)
        self.db.commit()
        self.db.refresh(message)
        return message

    # ------------------------------------------------------------------
    # Database access
    # ------------------------------------------------------------------

    async def _update_file_content(self, request_id: uuid.UUID, files_with_content: List[Dict[str, Any]]):
        """
        Write client-supplied file content into existing RetrievedFile rows.

        Args:
            request_id: ID of the FileRetrievalRequest that owns the files.
            files_with_content: Items read via ``item["filepath"]`` and
                ``item["content"]``. Items missing either value, or whose path is
                not already stored for this request, are skipped with a warning.

        Raises:
            Exception: Any error while querying/committing is logged, the
                transaction is rolled back, and the error is re-raised.
        """
        if not files_with_content:
            logger.warning("No files with content provided to update.")
            return

        logger.info(f"Starting content update for {len(files_with_content)} files for request {request_id}")

        try:
            # One query for all files of the request, then O(1) lookups by path.
            retrieved_files = self.db.query(file_retriever_models.RetrievedFile).filter_by(
                request_id=request_id
            ).all()
            file_map = {retrieved.file_path: retrieved for retrieved in retrieved_files}

            updated_count = 0
            for file_info in files_with_content:
                file_path = file_info.get("filepath")
                content = file_info.get("content")

                if not file_path or content is None:
                    logger.warning("Skipping file with missing filepath or content.")
                    continue

                db_file = file_map.get(file_path)
                if db_file is None:
                    logger.warning(f"File {file_path} not found in database for request {request_id}, skipping content update.")
                    continue

                db_file.content = content
                updated_count += 1
                logger.debug(f"Updated content for file: {file_path}")

            self.db.commit()
            logger.info(f"Successfully updated content for {updated_count} files.")

        except Exception:
            self.db.rollback()
            logger.exception(f"Failed to update file content for request {request_id}")
            raise

    async def _get_or_create_file_request(self, session_id: int, path: str, prompt: str) -> file_retriever_models.FileRetrievalRequest:
        """
        Return the FileRetrievalRequest for (*session_id*, *path*), creating it if needed.

        An existing request is reused for the same directory, but its
        ``question`` is overwritten with the latest *prompt*. The transaction is
        rolled back and the error re-raised if the commit fails.
        """
        file_request = self.db.query(file_retriever_models.FileRetrievalRequest).filter_by(
            session_id=session_id, directory_path=path
        ).first()

        try:
            if file_request is None:
                file_request = file_retriever_models.FileRetrievalRequest(
                    session_id=session_id,
                    question=prompt,
                    directory_path=path,
                )
                self.db.add(file_request)
            else:
                file_request.question = prompt
            self.db.commit()
        except Exception:
            self.db.rollback()
            raise

        self.db.refresh(file_request)
        return file_request

    async def _get_file_request_by_id(self, request_id: uuid.UUID) -> Optional[file_retriever_models.FileRetrievalRequest]:
        """Return the FileRetrievalRequest with *request_id*, or None if absent."""
        return self.db.query(file_retriever_models.FileRetrievalRequest).filter_by(id=request_id).first()

    async def _store_retrieved_files(self, request_id: uuid.UUID, files: List[Dict[str, Any]]):
        """
        Synchronize the stored RetrievedFile rows with the client's directory listing.

        Each item in *files* is read via ``item["path"]``, ``item["lastModified"]``
        (epoch milliseconds) and optionally ``item["name"]``.

        * New paths are inserted with empty content.
        * Known paths whose ``lastModified`` is newer than the stored
          ``last_updated`` get the new timestamp and have their content cleared,
          because the previously fetched content is now stale.
        * Stored paths absent from *files* (or skipped as malformed) are deleted.

        All changes are committed in one transaction, rolled back on failure.
        """
        existing_files = self.db.query(file_retriever_models.RetrievedFile).filter_by(request_id=request_id).all()
        existing_files_map = {existing.file_path: existing for existing in existing_files}

        # Paths present in the incoming listing; everything else gets purged.
        incoming_file_paths = set()

        try:
            for file_info in files:
                file_path = file_info.get("path")
                last_modified_timestamp_ms = file_info.get("lastModified")

                if not file_path or last_modified_timestamp_ms is None:
                    logger.warning("Skipping file with missing path or timestamp.")
                    continue

                try:
                    last_modified_datetime = datetime.fromtimestamp(float(last_modified_timestamp_ms) / 1000.0)
                except (TypeError, ValueError, OverflowError, OSError):
                    logger.warning(f"Skipping file {file_path} with invalid timestamp: {last_modified_timestamp_ms!r}")
                    continue

                incoming_file_paths.add(file_path)
                db_file = existing_files_map.get(file_path)

                if db_file is None:
                    logger.info(f"Adding new file: {file_path}")
                    self.db.add(file_retriever_models.RetrievedFile(
                        request_id=request_id,
                        file_path=file_path,
                        file_name=file_info.get("name", ""),
                        content="",  # Fetched later, only if the AI asks for it.
                        type="original",
                        last_updated=last_modified_datetime,
                    ))
                elif db_file.last_updated is None or last_modified_datetime > db_file.last_updated:
                    logger.info(f"Updating file {file_path}. New timestamp: {last_modified_datetime}")
                    db_file.last_updated = last_modified_datetime
                    # The cached content belongs to the old version; clear it so
                    # the AI sees the file as not-yet-fetched and re-requests it.
                    db_file.content = ""
                # Otherwise the stored copy is at least as new: nothing to do.

            files_to_purge = [
                existing for existing in existing_files if existing.file_path not in incoming_file_paths
            ]
            if files_to_purge:
                logger.info(f"Purging {len(files_to_purge)} non-existing files.")
                for stale in files_to_purge:
                    self.db.delete(stale)

            self.db.commit()
        except Exception:
            self.db.rollback()
            raise

        logger.info("File synchronization complete.")

    async def _retrieve_by_request_id(self, db: Session, request_id: str) -> Optional[Dict[str, Any]]:
        """
        Load a FileRetrievalRequest and its files as a JSON-serializable dict.

        Args:
            db: The SQLAlchemy database session.
            request_id: The request's UUID, as a string or ``uuid.UUID``.

        Returns:
            ``{"request_id", "question", "directory_path", "session_id",
            "created_at", "retrieved_files"}``, or None if *request_id* is not a
            valid UUID or no such request exists. Files with content are listed
            in full; files without content only as ``{"file_path", "type"}``.
        """
        request_uuid = self._to_uuid(request_id)
        if request_uuid is None:
            logger.warning(f"Invalid UUID format for request_id: {request_id}")
            return None

        # Eager-load the files in the same query to avoid one query per file.
        request = db.query(file_retriever_models.FileRetrievalRequest).filter(
            file_retriever_models.FileRetrievalRequest.id == request_uuid
        ).options(
            joinedload(file_retriever_models.FileRetrievalRequest.retrieved_files)
        ).first()

        if not request:
            return None

        retrieved_data = {
            "request_id": str(request.id),
            "question": request.question,
            "directory_path": request.directory_path,
            "session_id": request.session_id,
            "created_at": request.created_at.isoformat() if request.created_at else None,
            "retrieved_files": [],
        }

        for retrieved in request.retrieved_files:
            if retrieved.content:
                file_data = {
                    "file_path": retrieved.file_path,
                    "content": retrieved.content,
                    "id": str(retrieved.id),
                    "name": retrieved.file_name,
                    "type": retrieved.type,
                    "last_updated": retrieved.last_updated.isoformat() if retrieved.last_updated else None,
                    "created_at": retrieved.created_at.isoformat() if retrieved.created_at else None,
                }
            else:
                # Compact form keeps the prompt small for files not yet fetched.
                file_data = {
                    "file_path": retrieved.file_path,
                    "type": retrieved.type,
                }
            retrieved_data["retrieved_files"].append(file_data)

        return retrieved_data

    async def get_file_content_by_request_id_and_path(self, db: Session, request_id: uuid.UUID, file_path: str) -> str:
        """
        Return the stored content of *file_path* for the request *request_id*.

        Returns an empty string (and logs a warning) when the file is not stored
        for that request or has no content.
        """
        retrieved_file = db.query(file_retriever_models.RetrievedFile).filter_by(
            request_id=request_id, file_path=file_path
        ).first()
        if retrieved_file and retrieved_file.content:
            return retrieved_file.content
        logger.warning(f"File with path {file_path} not found for request ID {request_id} or has no content.")
        return ""

    async def get_files_by_request_id(self, db: Session, request_id: str) -> Optional[List[str]]:
        """
        Return the paths of all files attached to a FileRetrievalRequest.

        Args:
            db: The SQLAlchemy database session.
            request_id: The request's UUID, as a string or ``uuid.UUID``.

        Returns:
            A list of file paths, or None if *request_id* is not a valid UUID or
            no such request exists.
        """
        request_uuid = self._to_uuid(request_id)
        if request_uuid is None:
            logger.warning(f"Invalid UUID format for request_id: {request_id}")
            return None

        request = db.query(file_retriever_models.FileRetrievalRequest).filter(
            file_retriever_models.FileRetrievalRequest.id == request_uuid
        ).options(
            joinedload(file_retriever_models.FileRetrievalRequest.retrieved_files)
        ).first()

        if not request:
            return None

        return [retrieved.file_path for retrieved in request.retrieved_files]

    # ------------------------------------------------------------------
    # Websocket protocol
    # ------------------------------------------------------------------

    async def send_command(self, websocket: WebSocket, command_name: str, data: Optional[Dict[str, Any]] = None):
        """
        Send the command *command_name* to the client.

        The message is ``{"type": <command type>, **data}``; keys in *data*
        override ``type``.

        Raises:
            ValueError: If *command_name* is not declared in ``command_map``.
        """
        command = self.command_map.get(command_name)
        if command is None:
            raise ValueError(f"Unknown command: {command_name}")

        await self._send_json(websocket, {"type": command["type"], **(data or {})})

    async def dispatch_message(self, websocket: WebSocket, message: Dict[str, Any]):
        """
        Route *message* to the handler registered for its ``type``.

        Unknown types are logged and answered with an ``error`` message.
        """
        message_type = message.get("type")
        handler = self.message_handlers.get(message_type)
        if handler is None:
            logger.warning(f"No handler found for message type: {message_type}")
            await self._send_error(websocket, f"Unknown message type: {message_type}")
            return

        logger.debug(f"Dispatching to handler for message type: {message_type}")
        await handler(websocket, message)

    async def handle_list_directory_response(self, websocket: WebSocket, data: Dict[str, Any]):
        """
        Handle the client's answer to a ``list_directory`` command.

        Validates the listing (non-empty, has a known ``request_id``, passes the
        token-size precheck), syncs it into the database, then hands off to
        ``handle_files_content_response`` with no file content so the AI can
        decide what to do next.
        """
        logger.debug(f"List directory response data: {data}")
        files = data.get("files", [])
        if not files:
            await self._send_error(websocket, "Error: No files found in the directory.")
            return

        request_id = data.get("request_id")
        if not request_id:
            await self._send_error(websocket, "Error: request_id is missing in the response.")
            return

        try:
            Validator().precheck_tokensize(files)
        except TokenLimitExceededError as e:
            await self._send_error(websocket, f"Error: {e}")
            return

        # A malformed id is treated the same as an unknown one instead of
        # letting uuid.UUID raise out of the handler.
        request_uuid = self._to_uuid(request_id)
        file_request = await self._get_file_request_by_id(request_uuid) if request_uuid else None
        if not file_request:
            await self._send_error(websocket, f"Error: No matching file request found for request_id {request_id}.")
            return

        await self._store_retrieved_files(request_id=request_uuid, files=files)
        await self.handle_files_content_response(websocket, {
            "files": [],
            "request_id": request_id,
            "session_id": file_request.session_id,
            "provider_name": data.get("provider_name", "gemini"),
        })

    async def handle_files_content_response(self, websocket: WebSocket, data: Dict[str, Any]):
        """
        Store any received file content, then let the AI decide the next step.

        ``data`` is read for ``files`` (items with ``filepath``/``content``),
        ``request_id``, ``session_id`` and optionally ``provider_name``
        (default ``"gemini"``). The decider's ``decision`` selects the branch:

        * ``"answer"``      -- store and send the answer as a ``chat_message``.
        * ``"files"``       -- ask the client for the content of the listed files.
        * ``"code_change"`` -- run ``CodeChangeHelper`` and send the result.
        * anything else     -- store and send the raw answer as a ``chat_message``.
        """
        files_data: List[Dict[str, str]] = data.get("files", [])
        request_id = data.get("request_id")
        session_id = data.get("session_id")
        provider_name = data.get("provider_name", "gemini")

        if not request_id:
            await self._send_error(websocket, "Error: request_id is required to process file content.")
            return

        if not session_id:
            await self._send_error(websocket, "Error: session_id is required to process file content.")
            return

        request_uuid = self._to_uuid(request_id)
        if request_uuid is None:
            await self._send_error(websocket, f"Error: Invalid request_id format: {request_id}.")
            return

        if not files_data:
            logger.info(f"No files data received for request_id: {request_id}")
        else:
            logger.info(f"Received content for {len(files_data)} files (request_id: {request_id}).")
            await self._update_file_content(request_id=request_uuid, files_with_content=files_data)

        context_data = await self._retrieve_by_request_id(self.db, request_id=request_id)
        if not context_data:
            logger.error(f"Context not found for request_id: {request_id}")
            await self._send_error(websocket, "An internal error occurred. Please try again.")
            return

        await self._send_thinking_log(websocket, "AI is analyzing the retrieved files to determine next steps.")

        session = self.db.query(models.Session).options(
            joinedload(models.Session.messages)
        ).filter(models.Session.id == session_id).first()
        if session is None:
            await self._send_error(websocket, f"Error: Session with ID {session_id} not found.")
            return

        original_question = context_data.get("question", "")
        with dspy.context(lm=get_llm_provider(provider_name=provider_name)):
            crqd = CodeRagQuestionDecider()
            try:
                raw_answer_text, reasoning, decision = await crqd(
                    question=original_question,
                    history=session.messages,
                    retrieved_data=context_data,
                )
            except ValueError as e:
                await self._send_error(websocket, f"Failed to process AI decision request. Error: {e}")
                return

        if decision == "answer":
            self._add_message(session_id, "assistant", raw_answer_text)
            await self._send_json(websocket, {
                "type": "chat_message",
                "content": raw_answer_text,
                "reasoning": reasoning,
            })
        elif decision == "files":
            await self._request_selected_files(websocket, request_id, session_id, raw_answer_text, reasoning)
        elif decision == "code_change":
            await self._run_code_changes(
                websocket, request_uuid, original_question, raw_answer_text, reasoning, provider_name
            )
        else:
            await self._send_thinking_log(websocket, "Answering user's question directly.")
            # Persist like the "answer" branch so later turns see it in history.
            self._add_message(session_id, "assistant", raw_answer_text)
            await self._send_json(websocket, {
                "type": "chat_message",
                "content": raw_answer_text,
                "reasoning": reasoning,
            })

    @staticmethod
    def _parse_file_list(raw_answer_text: str) -> List[Any]:
        """
        Extract the Python list literal of file paths from the LLM's answer.

        The span from the first ``[`` to the last ``]`` (across newlines) is
        parsed when present, otherwise the whole text. ``ast.literal_eval`` only
        accepts literals, so model output is never executed.

        Raises:
            ValueError, SyntaxError, TypeError: If no list literal can be parsed.
        """
        match = re.search(r'\[.*\]', raw_answer_text, re.DOTALL)
        candidate = match.group(0) if match else raw_answer_text
        parsed = ast.literal_eval(candidate)
        if not isinstance(parsed, list):
            raise ValueError("Parsed result is not a list.")
        return parsed

    async def _request_selected_files(self, websocket: WebSocket, request_id: str, session_id: Any, raw_answer_text: str, reasoning: str) -> None:
        """Parse the AI's file list, record it, and ask the client for those files."""
        await self._send_thinking_log(websocket, f"AI decided files are needed: {raw_answer_text}.")
        try:
            answer_text = self._parse_file_list(raw_answer_text)
        except (ValueError, SyntaxError, TypeError) as e:
            logger.error(f"Error parsing LLM output: {e}")
            await self._send_thinking_log(websocket, f"Warning: AI's file list could not be parsed. Error: {e}")
            return

        self._add_message(session_id, "assistant", f"{reasoning}\n Request Files: {answer_text} ")
        await self.send_command(websocket, "get_file_content", data={"filepaths": answer_text, "request_id": request_id})

    async def _run_code_changes(self, websocket: WebSocket, request_uuid: uuid.UUID, original_question: str, raw_answer_text: str, reasoning: str, provider_name: str) -> None:
        """Run CodeChangeHelper on the AI's instructions and send the final changes."""
        await self._send_thinking_log(websocket, "AI is generating the necessary code changes. This may take a moment.")
        try:
            cch = CodeChangeHelper(
                db=self.db,
                provider_name=provider_name,
                original_question=original_question,
                input_data=raw_answer_text,
                reasoning=reasoning,
                request_id=request_uuid,
            )
            final_changes = await cch.process(websocket=websocket)

            payload = json.dumps({
                "type": "code_change",
                "code_changes": final_changes,
                "content": "Completed all requested code changes.",
                "done": True,
            })
            logger.info(f"Sending code change response to client: {payload}")
            await websocket.send_text(payload)
        except ValueError as e:  # also covers json.JSONDecodeError (a subclass)
            logger.error(f"Error processing code changes: {e}")
            await self._send_error(websocket, f"Failed to process code change request. Error: {e}")

    async def handle_command_output(self, websocket: WebSocket, data: Dict[str, Any]):
        """Log the output of a client-executed command and acknowledge it."""
        command = data.get("command")
        output = data.get("output")
        request_id = data.get("request_id")

        logger.info(f"Received output for command '{command}' (request_id: {request_id}). Output: {output}")

        await self._send_thinking_log(websocket, f"Command '{command}' completed. Analyzing output.")

    async def handle_chat_message(self, websocket: WebSocket, data: Dict[str, Any]):
        """
        Handle a user chat message.

        The message is stored in the session history. If ``path`` is provided
        the file-retrieval flow starts (the client is asked to list that
        directory); otherwise the question is answered directly with
        ``DspyRagPipeline`` and the answer is stored and sent back as a
        ``chat_message``.
        """
        prompt = data.get("content")
        provider_name = data.get("provider_name", "gemini")
        session_id = data.get("session_id")
        if session_id is None:
            await self._send_error(websocket, "Error: session_id is required for chat messages.")
            return

        if not isinstance(prompt, str) or not prompt.strip():
            await self._send_error(websocket, "Error: content is required for chat messages.")
            return

        session = self.db.query(models.Session).options(
            joinedload(models.Session.messages)
        ).filter(models.Session.id == session_id).first()
        if not session:
            await self._send_error(websocket, f"Error: Session with ID {session_id} not found.")
            return

        self._add_message(session_id, "user", prompt)

        path = data.get("path", "")
        if path:
            file_request = await self._get_or_create_file_request(session_id, path, prompt)
            await self.send_command(websocket, "list_directory", data={"request_id": str(file_request.id)})
            return

        llm_provider = get_llm_provider(provider_name)
        chat = DspyRagPipeline()
        with dspy.context(lm=llm_provider):
            answer_text = await chat(question=prompt, history=session.messages, context_chunks=[])

        self._add_message(session_id, "assistant", answer_text)

        await self._send_json(websocket, {
            "type": "chat_message",
            "content": answer_text,
        })
        logger.info(f"Sent chat response to client: {answer_text}")