import ast  # Import the Abstract Syntax Trees module
import json
import logging
import re
import uuid
from datetime import datetime
from typing import Any, Awaitable, Callable, Dict, List, Optional

import dspy
from fastapi import Depends, WebSocket
from sqlalchemy.orm import Session, joinedload

from app.db import models
from app.db import file_retriever_models
from app.db.session import SessionLocal
from app.core.providers.factory import get_llm_provider
from app.core.pipelines.file_selector import CodeRagFileSelector
from app.core.pipelines.dspy_rag import DspyRagPipeline
from app.core.pipelines.question_decider import CodeRagQuestionDecider
from app.core.pipelines.context_compressor import StringContextCompressor
from app.core.retrievers.file_retriever import FileRetriever
# Type alias for message handlers: an awaitable callable that takes the
# WebSocket and the message payload dict and returns nothing.
MessageHandler = Callable[[WebSocket, Dict[str, Any]], Awaitable[None]]
# Module-level logger named after this module. This only obtains the logger;
# handlers and levels are not configured here.
logger = logging.getLogger(__name__)
class WorkspaceService:
"""
Manages the full lifecycle of an AI workspace session, including
handling various message types and dispatching them to the correct handlers.
"""
def __init__(self):
    """Build the dispatch tables and create the session store, DB session and file retriever."""
    # Incoming message type -> bound coroutine that handles it.
    # The "select_folder_response" handler is currently disabled.
    handlers: Dict[str, MessageHandler] = {}
    handlers["list_directory_response"] = self.handle_list_directory_response
    handlers["file_content_response"] = self.handle_files_content_response
    handlers["execute_command_response"] = self.handle_command_output
    handlers["chat_message"] = self.handle_chat_message
    self.message_handlers: Dict[str, MessageHandler] = handlers

    # Commands that can be sent to the client. Each command's wire "type"
    # equals its name, so the map is derived from (name, description) pairs.
    command_descriptions = (
        ("list_directory", "Request a list of files and folders in the current directory."),
        ("get_file_content", "Request the content of a specific file."),
        ("execute_command", "Request to execute a shell command."),
    )
    self.command_map: Dict[str, Dict[str, Any]] = {
        name: {"type": name, "description": text}
        for name, text in command_descriptions
    }

    # Per-websocket session state, starts empty.
    self.sessions: Dict[str, Dict[str, Any]] = {}
    # One database session and one retriever are held for the lifetime of this service object.
    self.db = SessionLocal()
    self.file_retriever = FileRetriever()
# --- New helper function for reuse ---
async def _update_file_content(self, request_id: uuid.UUID, files_with_content: List[Dict[str, Any]]):
    """
    Copy client-supplied file contents onto the stored records of a request.

    Every entry of ``files_with_content`` is read for its "filepath" and
    "content" keys. Entries lacking either value are skipped with a warning,
    as are paths that have no ``RetrievedFile`` row for ``request_id``.
    All changes are committed together; if anything raises, the session is
    rolled back, the failure is logged and the exception is re-raised.
    """
    if not files_with_content:
        logger.warning("No files with content provided to update.")
        return

    logger.info(f"Starting content update for {len(files_with_content)} files for request {request_id}")
    try:
        # Load the request's rows once and index them by path for direct lookup.
        stored = self.db.query(file_retriever_models.RetrievedFile).filter_by(
            request_id=request_id
        ).all()
        records_by_path = {}
        for record in stored:
            records_by_path[record.file_path] = record

        changed = 0
        for entry in files_with_content:
            path, text = entry.get("filepath"), entry.get("content")
            # An empty string is valid content; only a missing value (None) is rejected.
            if text is None or not path:
                logger.warning("Skipping file with missing filename or content.")
                continue
            if path not in records_by_path:
                logger.warning(f"File {path} not found in database for request {request_id}, skipping content update.")
                continue
            records_by_path[path].content = text
            changed += 1
            logger.debug(f"Updated content for file: {path}")

        # Persist every assignment made above in one commit.
        self.db.commit()
        logger.info(f"Successfully updated content for {changed} files.")
    except Exception as e:
        self.db.rollback()
        logger.error(f"Failed to update file content for request {request_id}: {e}")
        raise
async def _get_or_create_file_request(self, session_id: int, path: str, prompt: str) -> file_retriever_models.FileRetrievalRequest:
    """
    Return the FileRetrievalRequest for (session_id, path), creating it if absent.

    In both cases the record's ``question`` ends up equal to ``prompt``, and
    the record is committed and refreshed before it is returned.
    """
    request_model = file_retriever_models.FileRetrievalRequest
    record = self.db.query(request_model).filter_by(
        session_id=session_id, directory_path=path
    ).first()

    if record:
        # Existing request: overwrite its question with the latest prompt.
        record.question = prompt
    else:
        record = request_model(
            session_id=session_id,
            question=prompt,
            directory_path=path,
        )
        self.db.add(record)

    self.db.commit()
    self.db.refresh(record)
    return record
async def _get_file_request_by_id(self, request_id: uuid.UUID) -> file_retriever_models.FileRetrievalRequest:
    """
    Look up a FileRetrievalRequest by its primary identifier.

    Returns the first row whose ``id`` equals ``request_id``; callers in this
    class treat a falsy result as "no such request".
    """
    lookup = self.db.query(file_retriever_models.FileRetrievalRequest)
    matching = lookup.filter_by(id=request_id)
    return matching.first()
async def _store_retrieved_files(self, request_id: uuid.UUID, files: List[Dict[str, Any]]):
    """
    Synchronize the stored file list of a request with the client's listing.

    For each entry of ``files`` (read for its "path", "lastModified" in
    milliseconds, and optional "name" keys):

    * a path not yet stored is added with empty content and type "original";
    * a stored path whose incoming timestamp is newer gets its
      ``last_updated`` refreshed (content is left untouched here);
    * entries lacking a path or a timestamp are skipped with a warning.

    Stored rows whose path does not appear among the accepted incoming
    entries are deleted. Everything is committed in a single transaction.

    Raises:
        Exception: any error raised while reading, comparing or committing.
            The session is rolled back before the exception is re-raised,
            so the shared session stays usable for later calls.
    """
    try:
        # 1. Load what is already stored for this request, indexed by path.
        existing_files = self.db.query(file_retriever_models.RetrievedFile).filter_by(request_id=request_id).all()
        known_files = {file.file_path: file for file in existing_files}
        incoming_file_paths = set()

        # 2. Walk the incoming listing, adding new files and refreshing stale ones.
        for file_info in files:
            file_path = file_info.get("path")
            last_modified_timestamp_ms = file_info.get("lastModified")
            if not file_path or last_modified_timestamp_ms is None:
                logger.warning("Skipping file with missing path or timestamp.")
                continue

            # NOTE(review): fromtimestamp() without tz yields a naive local-time
            # datetime; this assumes last_updated is stored naive as well - confirm.
            last_modified_datetime = datetime.fromtimestamp(last_modified_timestamp_ms / 1000.0)
            incoming_file_paths.add(file_path)

            db_file = known_files.get(file_path)
            if db_file is None:
                logger.info(f"Adding new file: {file_path}")
                new_file = file_retriever_models.RetrievedFile(
                    request_id=request_id,
                    file_path=file_path,
                    file_name=file_info.get("name", ""),
                    content="",  # Content is deliberately left empty; it is filled in later.
                    type="original",
                    last_updated=last_modified_datetime,
                )
                self.db.add(new_file)
                # Remember the pending row so a path listed twice does not
                # produce two inserts for the same file.
                known_files[file_path] = new_file
            elif db_file.last_updated is None or last_modified_datetime > db_file.last_updated:
                # A row without a timestamp is treated as stale instead of
                # letting the None comparison raise TypeError.
                logger.info(f"Updating file {file_path}. New timestamp: {last_modified_datetime}")
                db_file.last_updated = last_modified_datetime
            # Otherwise the stored row is identical or newer: nothing to do.

        # 3. Purge stored files that the client no longer lists.
        files_to_purge = [
            file for file in existing_files if file.file_path not in incoming_file_paths
        ]
        if files_to_purge:
            logger.info(f"Purging {len(files_to_purge)} non-existing files.")
            for file in files_to_purge:
                self.db.delete(file)

        # 4. Commit updates, additions and deletions together.
        self.db.commit()
        logger.info("File synchronization complete.")
    except Exception:
        # Mirror _update_file_content: never leave the shared session mid-transaction.
        self.db.rollback()
        logger.exception(f"Failed to synchronize files for request {request_id}")
        raise
# def generate_request_id(self) -> str:
# """Generates a unique request ID."""
# return str(uuid.uuid4())
async def send_command(self, websocket: WebSocket, command_name: str, data: Optional[Dict[str, Any]] = None):
    """
    Send one of the predefined commands to the client as a JSON text frame.

    Args:
        websocket: Connection the command is written to.
        command_name: Key into ``self.command_map``.
        data: Extra fields merged into the message after "type", so a
            "type" key in ``data`` overrides the command's own type.
            ``None`` (the default) means no extra fields.

    Raises:
        ValueError: If ``command_name`` is not in ``self.command_map``.
    """
    command = self.command_map.get(command_name)
    if command is None:
        raise ValueError(f"Unknown command: {command_name}")
    # The default is None, not {}: a dict default would be one object shared
    # across every call of this method.
    message_to_send = {"type": command["type"], **(data or {})}
    await websocket.send_text(json.dumps(message_to_send))
async def dispatch_message(self, websocket: WebSocket, message: Dict[str, Any]):
    """
    Route an incoming message to the handler registered for its "type".

    When no handler matches, an "error" message naming the unknown type is
    sent back over the websocket. This includes a missing or non-string
    "type" and a ``message`` that is not a dict.

    Args:
        websocket: Connection the message arrived on; passed to the handler.
        message: Decoded message payload.
    """
    message_type = message.get("type") if isinstance(message, dict) else None
    # Only string types can match a handler. A list or dict "type" is
    # unhashable and would make the dict lookup raise TypeError, so it is
    # reported as unknown instead.
    handler = self.message_handlers.get(message_type) if isinstance(message_type, str) else None
    if handler is None:
        logger.warning(f"No handler found for message type: {message_type}")
        await websocket.send_text(json.dumps({"type": "error", "content": f"Unknown message type: {message_type}"}))
        return
    logger.debug(f"Dispatching to handler for message type: {message_type}")
    await handler(websocket, message)
# async def handle_select_folder_response(self, websocket:WebSocket, data: Dict[str, Any]):
# """Handles the client's response to a select folder response."""
# path = data.get("path")
# request_id = data.get("request_id")
# if request_id is None:
# await websocket.send_text(json.dumps({
# "type": "error",
# "content": "Error: request_id is missing in the response."
# }))
# return
# print(f"Received folder selected (request_id: {request_id}): Path: {path}")
# After a folder is selected, the next step is to list its contents.
# This now uses the send_command helper.
# await self.send_command(websocket, "list_directory", data={"path": path, "request_id": request_id})
async def handle_list_directory_response(self, websocket: WebSocket, data: Dict[str, Any]):
    """
    Handle the client's response to a list_directory request.

    Steps:
      1. Validate that "files" is non-empty and that "request_id" is present,
         is a well-formed UUID and matches a stored file request. Any
         failure sends an "error" message and stops.
      2. Synchronize the stored file list with the listing.
      3. Ask the file-selector pipeline, using the provider named by
         "provider_name" (default "gemini"), which files it needs.
      4. Parse the pipeline output as a Python list literal. On failure a
         "thinking_log" warning is sent and nothing is requested.
      5. Send a "get_file_content" command for the selected paths.
    """
    logger.debug(f"List directory response data: {data}")
    files = data.get("files", [])
    if not files:
        await websocket.send_text(json.dumps({
            "type": "error",
            "content": "Error: No files found in the directory."
        }))
        return
    request_id = data.get("request_id")
    if not request_id:
        await websocket.send_text(json.dumps({
            "type": "error",
            "content": "Error: request_id is missing in the response."
        }))
        return
    # Parse once and reuse. A malformed id raises ValueError; a non-string
    # value raises AttributeError or TypeError inside uuid.UUID.
    try:
        request_uuid = uuid.UUID(request_id)
    except (ValueError, AttributeError, TypeError):
        await websocket.send_text(json.dumps({
            "type": "error",
            "content": f"Error: request_id {request_id!r} is not a valid UUID."
        }))
        return
    file_request = await self._get_file_request_by_id(request_uuid)
    if not file_request:
        await websocket.send_text(json.dumps({
            "type": "error",
            "content": f"Error: No matching file request found for request_id {request_id}."
        }))
        return
    await self._store_retrieved_files(request_id=request_uuid, files=files)
    provider_name = data.get("provider_name", "gemini")
    llm_provider = get_llm_provider(provider_name)
    cfs = CodeRagFileSelector()
    with dspy.context(lm=llm_provider):
        raw_answer_text = await cfs(
            question=file_request.question,
            retrieved_data=self.file_retriever.retrieve_by_request_id(self.db, request_id=request_id)
        )
    try:
        # literal_eval only evaluates literals, so it is safe on model output,
        # but on malformed input it may raise any of the exceptions below.
        answer_text = ast.literal_eval(raw_answer_text)
        if not isinstance(answer_text, list):
            raise ValueError("Parsed result is not a list.")
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
        # The LLM output was not a valid list literal: tell the client and stop.
        logger.error(f"Error parsing LLM output: {e}")
        await websocket.send_text(json.dumps({
            "type": "thinking_log",
            "content": f"Warning: AI's file list could not be parsed. Error: {e}"
        }))
        return
    await websocket.send_text(json.dumps({
        "type": "thinking_log",
        "content": f"AI selected files: {answer_text}. Now requesting file content."
    }))
    # With the AI's selection in hand, ask the client for the content of those files.
    await self.send_command(websocket, "get_file_content", data={"filepaths": answer_text, "request_id": request_id})
async def handle_files_content_response(self, websocket: WebSocket, data: Dict[str, Any]):
    """
    Handle file contents sent by the client and decide the next step.

    The contents in "files" are stored for the request identified by
    "request_id". The question-decider pipeline then runs with the provider
    named by "provider_name" (default "gemini"), and its decision selects
    the reply:

    * "files": the answer is parsed as a JSON list of paths and a
      "get_file_content" command is sent for them;
    * "code_change": a "chat_message" with the answer and "code_diff";
    * anything else: a "chat_message" with the answer only.

    An empty "files" list is logged and ignored. A missing or malformed
    "request_id", or missing context for the request, produces an "error"
    message.
    """
    files_data: List[Dict[str, str]] = data.get("files", [])
    request_id = data.get("request_id")
    if not files_data:
        logger.warning(f"No files data received for request_id: {request_id}")
        return
    # uuid.UUID raises TypeError for None, ValueError for a malformed string
    # and AttributeError for other non-string values.
    try:
        request_uuid = uuid.UUID(request_id)
    except (ValueError, AttributeError, TypeError):
        await websocket.send_text(json.dumps({
            "type": "error",
            "content": "Error: request_id is missing or is not a valid UUID."
        }))
        return
    logger.info(f"Received content for {len(files_data)} files (request_id: {request_id}).")
    await self._update_file_content(request_id=request_uuid, files_with_content=files_data)
    # Retrieve the updated context from the database.
    context_data = self.file_retriever.retrieve_by_request_id(self.db, request_id=request_id)
    if not context_data:
        logger.error(f"Context not found for request_id: {request_id}")
        await websocket.send_text(json.dumps({
            "type": "error",
            "content": "An internal error occurred. Please try again."
        }))
        return
    # Honor the client's provider choice like the other handlers do; "gemini" stays the default.
    provider_name = data.get("provider_name", "gemini")
    with dspy.context(lm=get_llm_provider(provider_name)):
        crqd = CodeRagQuestionDecider()
        raw_answer_text, decision, code_diff = await crqd(
            question=context_data.get("question", ""),
            history="",
            retrieved_data=context_data
        )
    if decision == "files":
        await websocket.send_text(json.dumps({
            "type": "thinking_log",
            "content": f"AI decided more files are needed: {raw_answer_text}."
        }))
        try:
            # The answer is expected to be a JSON list of paths.
            file_list = json.loads(raw_answer_text)
            if not isinstance(file_list, list):
                raise ValueError("Parsed result is not a list.")
        except (ValueError, TypeError) as e:
            # ValueError covers json.JSONDecodeError; TypeError covers a non-string answer.
            logger.error(f"Error parsing LLM output: {e}")
            await websocket.send_text(json.dumps({
                "type": "thinking_log",
                "content": f"Warning: AI's file list could not be parsed. Error: {e}"
            }))
            return
        await self.send_command(websocket, "get_file_content", data={"filepaths": file_list, "request_id": request_id})
    elif decision == "code_change":
        await websocket.send_text(json.dumps({
            "type": "chat_message",
            "content": raw_answer_text,
            "code_diff": code_diff
        }))
    else:  # any other decision is treated as a plain answer
        await websocket.send_text(json.dumps({
            "type": "chat_message",
            "content": raw_answer_text
        }))
async def handle_command_output(self, websocket: WebSocket, data: Dict[str, Any]):
    """
    Acknowledge the output of a command the client executed.

    Prints the received "command", "request_id" and "output" fields, then
    sends a "thinking_log" message saying the command completed.
    """
    command, output, request_id = (
        data.get(field) for field in ("command", "output", "request_id")
    )
    print(f"Received output for command '{command}' (request_id: {request_id}). Output: {output}")
    # Only an acknowledgement is sent; the output itself is not processed further here.
    status = {
        "type": "thinking_log",
        "content": f"Command '{command}' completed. Analyzing output.",
    }
    await websocket.send_text(json.dumps(status))
async def handle_chat_message(self, websocket: WebSocket, data: Dict[str, Any]):
    """
    Handle an incoming chat message.

    Reads "content", "session_id", optional "provider_name" (default
    "gemini") and optional "path" from ``data``.

    * A missing "session_id", a missing or blank "content", or an unknown
      session produces an "error" message and nothing is stored.
    * Otherwise the user message is saved.
    * If "path" is non-empty, a file request is fetched or created for it
      and a "list_directory" command is sent; no answer is produced yet.
    * If "path" is empty, the chat pipeline answers the prompt; the answer
      is saved as an assistant message and sent as a "chat_message".
    """
    # TODO: Enhance this function to process the chat message and determine the next action.
    prompt = data.get("content")
    provider_name = data.get("provider_name", "gemini")
    session_id = data.get("session_id")
    if session_id is None:
        await websocket.send_text(json.dumps({
            "type": "error",
            "content": "Error: session_id is required for chat messages."
        }))
        return
    # Reject missing or blank prompts before anything is stored or sent to the LLM.
    if not isinstance(prompt, str) or not prompt.strip():
        await websocket.send_text(json.dumps({
            "type": "error",
            "content": "Error: content is required for chat messages."
        }))
        return
    session = self.db.query(models.Session).options(
        joinedload(models.Session.messages)
    ).filter(models.Session.id == session_id).first()
    if not session:
        await websocket.send_text(json.dumps({
            "type": "error",
            "content": f"Error: Session with ID {session_id} not found."
        }))
        return
    user_message = models.Message(session_id=session_id, sender="user", content=prompt)
    self.db.add(user_message)
    self.db.commit()
    self.db.refresh(user_message)
    path = data.get("path", "")
    if path:
        # A path starts the file retrieval flow; the reply comes from later handlers.
        file_request = await self._get_or_create_file_request(session_id, path, prompt)
        await self.send_command(websocket, "list_directory", data={"request_id": str(file_request.id)})
        return
    llm_provider = get_llm_provider(provider_name)
    chat = DspyRagPipeline()
    with dspy.context(lm=llm_provider):
        answer_text = await chat(question=prompt, history=session.messages, context_chunks=[])
    # Save the assistant's response.
    assistant_message = models.Message(session_id=session_id, sender="assistant", content=answer_text)
    self.db.add(assistant_message)
    self.db.commit()
    self.db.refresh(assistant_message)
    # Send the response back to the client as a chat message.
    await websocket.send_text(json.dumps({
        "type": "chat_message",
        "content": answer_text
    }))
    logger.info(f"Sent chat response to client: {answer_text}")