import dspy
import json
import uuid
import os
import re
import logging
from datetime import datetime
import ast # Import the Abstract Syntax Trees module
from typing import Dict, Any, Callable, Awaitable, List, Optional
from fastapi import WebSocket,Depends
from sqlalchemy.orm import Session,joinedload
from app.db import models
from app.db import file_retriever_models
from app.db.session import SessionLocal
from app.core.providers.factory import get_llm_provider
from app.core.pipelines.file_selector import CodeRagFileSelector
from app.core.pipelines.dspy_rag import DspyRagPipeline
from app.core.pipelines.question_decider import CodeRagQuestionDecider
# Signature shared by all message handlers: (websocket, parsed message) -> awaitable.
MessageHandler = Callable[[WebSocket, Dict[str, Any]], Awaitable[None]]
# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)
class WorkspaceService:
    """
    Manages the full lifecycle of an AI workspace session: routes incoming
    websocket messages to type-specific handlers and issues commands back
    to the client.
    """
    def __init__(self):
        # Dispatcher table: message "type" string -> coroutine that handles it.
        handlers: Dict[str, MessageHandler] = {}
        # handlers["select_folder_response"] = self.handle_select_folder_response
        handlers["list_directory_response"] = self.handle_list_directory_response
        handlers["file_content_response"] = self.handle_files_content_response
        handlers["execute_command_response"] = self.handle_command_output
        handlers["chat_message"] = self.handle_chat_message
        self.message_handlers = handlers
        # Catalogue of commands this service may send to the client.
        self.command_map: Dict[str, Dict[str, Any]] = {
            "list_directory": {"type": "list_directory", "description": "Request a list of files and folders in the current directory."},
            "get_file_content": {"type": "get_file_content", "description": "Request the content of a specific file."},
            "execute_command": {"type": "execute_command", "description": "Request to execute a shell command."},
        }
        # Per-websocket session state (keyed by session identifier).
        self.sessions: Dict[str, Dict[str, Any]] = {}
        # NOTE(review): one long-lived Session shared by every handler —
        # confirm this is intended; SQLAlchemy sessions are not thread-safe.
        self.db = SessionLocal()
# --- New helper function for reuse ---
async def _update_file_content(self, request_id: uuid.UUID, files_with_content: List[Dict[str, Any]]):
    """
    Update the 'content' column of existing RetrievedFile rows.

    Called once the client has sent back the content of the files the AI
    selected. Each entry in ``files_with_content`` must carry a "filepath"
    and a "content" key; entries missing either are skipped.

    Args:
        request_id: The FileRetrievalRequest these files belong to.
        files_with_content: Client-supplied list of {"filepath", "content"} dicts.

    Raises:
        Re-raises any database error after rolling back the session.
    """
    if not files_with_content:
        logger.warning("No files with content provided to update.")
        return
    logger.info(f"Starting content update for {len(files_with_content)} files for request {request_id}")
    try:
        # Build a path -> row map so each incoming file is matched in O(1).
        retrieved_files = self.db.query(file_retriever_models.RetrievedFile).filter_by(
            request_id=request_id
        ).all()
        file_map = {file.file_path: file for file in retrieved_files}
        updated_count = 0
        for file_info in files_with_content:
            file_path = file_info.get("filepath")
            content = file_info.get("content")
            # Empty-string content is valid; skip only when the key is absent.
            if not file_path or content is None:
                # BUG FIX: message previously said "filename" but the key read is "filepath".
                logger.warning("Skipping file with missing filepath or content.")
                continue
            if file_path in file_map:
                db_file = file_map[file_path]
                db_file.content = content
                updated_count += 1
                logger.debug(f"Updated content for file: {file_path}")
            else:
                logger.warning(f"File {file_path} not found in database for request {request_id}, skipping content update.")
        # Single commit so the whole update is one transaction.
        self.db.commit()
        logger.info(f"Successfully updated content for {updated_count} files.")
    except Exception as e:
        self.db.rollback()
        logger.error(f"Failed to update file content for request {request_id}: {e}")
        raise
async def _get_or_create_file_request(self, session_id: int, path: str, prompt: str) -> file_retriever_models.FileRetrievalRequest:
    """
    Fetch the FileRetrievalRequest for (session_id, path), creating it on
    first use; in both cases the stored question is refreshed to ``prompt``.
    """
    file_request = self.db.query(file_retriever_models.FileRetrievalRequest).filter_by(
        session_id=session_id, directory_path=path
    ).first()
    if file_request:
        # Existing request: just record the latest question.
        file_request.question = prompt
    else:
        # First request for this directory in this session.
        file_request = file_retriever_models.FileRetrievalRequest(
            session_id=session_id,
            question=prompt,
            directory_path=path,
        )
        self.db.add(file_request)
    # Both branches previously committed and refreshed; do it once here.
    self.db.commit()
    self.db.refresh(file_request)
    return file_request
async def _get_file_request_by_id(self, request_id: uuid.UUID) -> file_retriever_models.FileRetrievalRequest:
    """Return the FileRetrievalRequest with the given primary key, or None."""
    query = self.db.query(file_retriever_models.FileRetrievalRequest)
    return query.filter_by(id=request_id).first()
async def _store_retrieved_files(self, request_id: uuid.UUID, files: List[Dict[str, Any]]):
    """
    Synchronize RetrievedFile rows with the client's directory listing.

    Newer incoming files refresh the stored timestamp, unseen paths are
    inserted (with empty content, to be fetched later), and rows whose path
    no longer appears client-side are deleted. All changes are committed in
    a single transaction.
    """
    # Snapshot what the database currently holds for this request.
    current_rows = self.db.query(file_retriever_models.RetrievedFile).filter_by(request_id=request_id).all()
    rows_by_path = {row.file_path: row for row in current_rows}
    # Paths present in the incoming listing; anything else gets purged.
    seen_paths = set()
    for incoming in files:
        path = incoming.get("path")
        modified_ms = incoming.get("lastModified")
        if not path or modified_ms is None:
            logger.warning("Skipping file with missing path or timestamp.")
            continue
        # Client reports epoch milliseconds; convert to a datetime.
        # NOTE(review): fromtimestamp() yields naive local time — confirm the
        # DB's last_updated column uses the same convention.
        modified_at = datetime.fromtimestamp(modified_ms / 1000.0)
        seen_paths.add(path)
        row = rows_by_path.get(path)
        if row is None:
            # Newly introduced file; content is deliberately left empty and
            # will be requested from the client later.
            logger.info(f"Adding new file: {path}")
            self.db.add(file_retriever_models.RetrievedFile(
                request_id=request_id,
                file_path=path,
                file_name=incoming.get("name", ""),
                content="",
                type="original",
                last_updated=modified_at,
            ))
        elif modified_at > row.last_updated:
            # Client copy is newer: bump the timestamp (content refetched later).
            logger.info(f"Updating file {path}. New timestamp: {modified_at}")
            row.last_updated = modified_at
        # Identical or older copies need no change.
    # Remove rows for files the client no longer reports.
    stale_rows = [row for row in current_rows if row.file_path not in seen_paths]
    if stale_rows:
        logger.info(f"Purging {len(stale_rows)} non-existing files.")
        for row in stale_rows:
            self.db.delete(row)
    self.db.commit()
    logger.info("File synchronization complete.")
# def generate_request_id(self) -> str:
# """Generates a unique request ID."""
# return str(uuid.uuid4())
async def _retrieve_by_request_id(self, db: Session, request_id: str) -> Optional[Dict[str, Any]]:
    """
    Load a FileRetrievalRequest and its files as a JSON-serializable dict.

    Args:
        db: The SQLAlchemy database session.
        request_id: The UUID (string form) of the FileRetrievalRequest.

    Returns:
        A dictionary with the request metadata plus a "retrieved_files" list,
        or None when the id is malformed or no such request exists.
    """
    try:
        request_uuid = uuid.UUID(request_id)
    except ValueError:
        # CONSISTENCY FIX: use the module logger instead of print().
        logger.error(f"Invalid UUID format for request_id: {request_id}")
        return None
    # Eager-load retrieved_files to avoid the N+1 query problem.
    request = db.query(file_retriever_models.FileRetrievalRequest).filter(
        file_retriever_models.FileRetrievalRequest.id == request_uuid
    ).options(
        joinedload(file_retriever_models.FileRetrievalRequest.retrieved_files)
    ).first()
    if not request:
        return None
    retrieved_data = {
        "request_id": str(request.id),
        "question": request.question,
        "directory_path": request.directory_path,
        "session_id": request.session_id,
        "created_at": request.created_at.isoformat() if request.created_at else None,
        "retrieved_files": []
    }
    for file in request.retrieved_files:
        if file.content:
            # Files that already carry content get the full detailed record.
            file_data = {
                "file_path": file.file_path,
                "content": file.content,
                "id": str(file.id),
                "name": file.file_name,
                "type": file.type,
                "last_updated": file.last_updated.isoformat() if file.last_updated else None,
                "created_at": file.created_at.isoformat() if file.created_at else None,
            }
        else:
            # Content-less files use a compact representation.
            file_data = {
                "file_path": file.file_path,
                "type": file.type
            }
        retrieved_data["retrieved_files"].append(file_data)
    return retrieved_data
async def get_file_content_by_request_id_and_path(self, db: Session, request_id: uuid.UUID, file_path: str) -> str:
    """
    Return the stored content of one retrieved file.

    Args:
        db: The SQLAlchemy database session.
        request_id: The owning FileRetrievalRequest id.
        file_path: Path of the file within that request.

    Raises:
        ValueError: If no matching row exists or its content is empty.
    """
    # DOC FIX: the old docstring was copy-pasted from _get_file_request_by_id.
    # Local renamed to snake_case per PEP 8.
    retrieved_file = db.query(file_retriever_models.RetrievedFile).filter_by(
        request_id=request_id, file_path=file_path
    ).first()
    if retrieved_file and retrieved_file.content:
        return retrieved_file.content
    raise ValueError(f"File with path {file_path} not found for request ID {request_id} or has no content.")
async def _handle_code_change_response(self, db: Session, request_id: str, code_diff: str) -> List[Dict[str, Any]]:
    """
    Split a multi-file unified diff into per-file entries enriched with the
    original and patched content of each touched file.
    """
    # Split on every "diff --git" boundary while keeping the separators.
    per_file_diffs = re.split(r'(?=\ndiff --git a\/)', code_diff)
    results: List[Dict[str, Any]] = []
    for file_diff in per_file_diffs:
        if not file_diff.strip():
            continue
        # The "--- a/<path>" line identifies the file this diff touches.
        path_match = re.search(r'--- a(.*)', file_diff)
        if not path_match:
            continue
        file_path = path_match.group(1).strip()
        # Pull the stored original so we can compute the patched result.
        original_content = await self.get_file_content_by_request_id_and_path(
            db,
            uuid.UUID(request_id),
            file_path
        )
        results.append({
            "filepath": file_path,
            "diff": file_diff,
            "original_content": original_content,
            "new_content": self._apply_diff(original_content, file_diff)
        })
    return results
async def get_files_by_request_id(self, db: Session, request_id: str) -> Optional[List[str]]:
    """
    Return the file paths attached to a FileRetrievalRequest.

    Args:
        db: The SQLAlchemy database session.
        request_id: The UUID (string form) of the FileRetrievalRequest.

    Returns:
        A list of file paths, or None if the id is malformed or unknown.
    """
    try:
        request_uuid = uuid.UUID(request_id)
    except ValueError:
        # CONSISTENCY FIX: use the module logger instead of print().
        logger.error(f"Invalid UUID format for request_id: {request_id}")
        return None
    request = db.query(file_retriever_models.FileRetrievalRequest).filter(
        file_retriever_models.FileRetrievalRequest.id == request_uuid
    ).options(
        joinedload(file_retriever_models.FileRetrievalRequest.retrieved_files)
    ).first()
    if not request:
        return None
    # IDIOM: comprehension instead of a manual append loop.
    return [file.file_path for file in request.retrieved_files]
def _format_diff(self, raw_diff: str) -> str:
# Remove Markdown-style code block markers
content = re.sub(r'^```diff\n|```$', '', raw_diff.strip(), flags=re.MULTILINE)
# Unescape common sequences
content = content.encode('utf-8').decode('unicode_escape')
return content
def _apply_diff(self, original_content: str, file_diff: str) -> str:
"""
Applies a unified diff to the original content and returns the new content.
Args:
original_content: The original file content as a single string.
file_diff: The unified diff string.
Returns:
The new content with the diff applied.
"""
original_lines = original_content.splitlines(keepends=True)
diff_lines = file_diff.splitlines(keepends=True)
# Skip diff headers like --- / +++
i = 0
while i < len(diff_lines) and not diff_lines[i].startswith('@@'):
i += 1
if i == len(diff_lines):
return original_content # No hunks to apply
new_content: List[str] = []
orig_idx = 0 # Pointer in original_lines
while i < len(diff_lines):
hunk_header = diff_lines[i]
m = re.match(r'^@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@', hunk_header)
if not m:
raise ValueError(f"Invalid hunk header: {hunk_header.strip()}")
orig_start = int(m.group(1)) - 1 # line numbers in diff are 1-based
i += 1
# Add unchanged lines before the hunk
while orig_idx < orig_start:
new_content.append(original_lines[orig_idx])
orig_idx += 1
# Process hunk lines
while i < len(diff_lines) and not diff_lines[i].startswith('@@'):
line = diff_lines[i]
if line.startswith(' '):
new_content.append(original_lines[orig_idx])
orig_idx += 1
elif line.startswith('-'):
orig_idx += 1
elif line.startswith('+'):
new_content.append(line[1:]) # Add the new line without '+'
i += 1
# Add the remaining lines from the original
new_content.extend(original_lines[orig_idx:])
return ''.join(new_content)
async def send_command(self, websocket: WebSocket, command_name: str, data: Optional[Dict[str, Any]] = None):
    """
    Send a registered command to the client over the websocket.

    Args:
        websocket: The client connection.
        command_name: Key into self.command_map.
        data: Extra payload fields merged into the outgoing message.

    Raises:
        ValueError: If command_name is not a registered command.
    """
    if command_name not in self.command_map:
        raise ValueError(f"Unknown command: {command_name}")
    # BUG FIX: the default was a mutable dict ({}), which Python shares
    # across calls; None + fallback is the safe, backward-compatible form.
    message_to_send = {
        "type": self.command_map[command_name]["type"],
        **(data or {}),
    }
    await websocket.send_text(json.dumps(message_to_send))
async def dispatch_message(self, websocket: WebSocket, message: Dict[str, Any]):
    """
    Route an incoming client message to the handler registered for its
    'type' field; unknown types are reported back to the client as errors.
    """
    message_type = message.get("type")
    handler = self.message_handlers.get(message_type)
    if handler:
        logger.debug(f"Dispatching to handler for message type: {message_type}")
        await handler(websocket, message)
    else:
        # CONSISTENCY FIX: use the module logger instead of print().
        logger.warning(f"Warning: No handler found for message type: {message_type}")
        await websocket.send_text(json.dumps({"type": "error", "content": f"Unknown message type: {message_type}"}))
# async def handle_select_folder_response(self, websocket:WebSocket, data: Dict[str, Any]):
# """Handles the client's response to a select folder response."""
# path = data.get("path")
# request_id = data.get("request_id")
# if request_id is None:
# await websocket.send_text(json.dumps({
# "type": "error",
# "content": "Error: request_id is missing in the response."
# }))
# return
# print(f"Received folder selected (request_id: {request_id}): Path: {path}")
# After a folder is selected, the next step is to list its contents.
# This now uses the send_command helper.
# await self.send_command(websocket, "list_directory", data={"path": path, "request_id": request_id})
async def handle_list_directory_response(self, websocket: WebSocket, data: Dict[str, Any]):
    """
    Handle the client's directory listing: persist the file metadata, ask
    the LLM which files are relevant to the question, then request the
    content of those files from the client.
    """
    logger.debug(f"List directory response data: {data}")
    files = data.get("files", [])
    if not files:
        await websocket.send_text(json.dumps({
            "type": "error",
            "content": "Error: No files found in the directory."
        }))
        return
    request_id = data.get("request_id")
    if not request_id:
        await websocket.send_text(json.dumps({
            "type": "error",
            "content": "Error: request_id is missing in the response."
        }))
        return
    file_request = await self._get_file_request_by_id(uuid.UUID(request_id))
    if not file_request:
        await websocket.send_text(json.dumps({
            "type": "error",
            "content": f"Error: No matching file request found for request_id {request_id}."
        }))
        return
    # Sync the DB's file records with the freshly listed directory.
    await self._store_retrieved_files(request_id=uuid.UUID(request_id), files=files)
    provider_name = data.get("provider_name", "gemini")
    llm_provider = get_llm_provider(provider_name)
    cfs = CodeRagFileSelector()
    with dspy.context(lm=llm_provider):
        raw_answer_text = await cfs(
            question=file_request.question,
            retrieved_data=await self.get_files_by_request_id(self.db, request_id=request_id)
        )
    try:
        # ast.literal_eval safely parses the LLM's Python-list-shaped reply.
        answer_text = ast.literal_eval(raw_answer_text)
    except (ValueError, SyntaxError) as e:
        # CONSISTENCY FIX: use the module logger instead of print().
        logger.error(f"Error parsing LLM output: {e}")
        await websocket.send_text(json.dumps({
            "type": "thinking_log",
            "content": f"Warning: AI's file list could not be parsed. Error: {e}"
        }))
        return
    if len(answer_text) == 0:
        await websocket.send_text(json.dumps({
            "type": "thinking_log",
            "content": "AI did not select any files to retrieve content for."
        }))
        # Continue the pipeline with an empty file set so the decider still runs.
        await self.handle_files_content_response(websocket, {"files": [], "request_id": request_id})
        return
    await websocket.send_text(json.dumps({
        "type": "thinking_log",
        "content": f"AI selected files: {answer_text}. Now requesting file content."
    }))
    # Ask the client for the content of the AI-selected files.
    await self.send_command(websocket, "get_file_content", data={"filepaths": answer_text, "request_id": request_id})
async def handle_files_content_response(self, websocket: WebSocket, data: Dict[str, Any]):
    """
    Handle file content sent by the client: persist it, then let the LLM
    decide whether to request more files, propose a code change, or answer
    the user's question directly.
    """
    files_data: List[Dict[str, str]] = data.get("files", [])
    request_id = data.get("request_id")
    # ROBUSTNESS FIX: uuid.UUID(None) below would raise TypeError; fail fast
    # with a client-visible error instead.
    if not request_id:
        await websocket.send_text(json.dumps({
            "type": "error",
            "content": "Error: request_id is missing in the response."
        }))
        return
    if not files_data:
        # CONSISTENCY FIX: module logger instead of print() throughout.
        logger.warning(f"Warning: No files data received for request_id: {request_id}")
    else:
        logger.info(f"Received content for {len(files_data)} files (request_id: {request_id}).")
        await self._update_file_content(request_id=uuid.UUID(request_id), files_with_content=files_data)
    # Retrieve the updated context (question + files) from the database.
    context_data = await self._retrieve_by_request_id(self.db, request_id=request_id)
    if not context_data:
        logger.error(f"Error: Context not found for request_id: {request_id}")
        await websocket.send_text(json.dumps({
            "type": "error",
            "content": "An internal error occurred. Please try again."
        }))
        return
    await websocket.send_text(json.dumps({
        "type": "thinking_log",
        "content": f"AI is analyzing the retrieved files to determine next steps."
    }))
    # NOTE(review): provider is hard-coded to "gemini" here, unlike other
    # handlers that honor data["provider_name"] — confirm this is intended.
    with dspy.context(lm=get_llm_provider(provider_name="gemini")):
        crqd = CodeRagQuestionDecider()
        raw_answer_text, reasoning, decision, code_diff = await crqd(
            question=context_data.get("question", ""),
            history="",
            retrieved_data=context_data
        )
    dspy.inspect_history(n=1)  # Inspect the last DSPy operation for debugging
    if decision == "files":
        await websocket.send_text(json.dumps({
            "type": "thinking_log",
            "content": f"AI decided more files are needed: {raw_answer_text}."
        }))
        try:
            # Prefer a bracketed list embedded in the reply (the LLM often
            # wraps it in prose or a code fence); otherwise parse the whole
            # reply. Either way it must literal_eval to a list.
            json_match = re.search(r'\[.*\]', raw_answer_text, re.DOTALL)
            candidate = json_match.group(0) if json_match else raw_answer_text
            answer_text = ast.literal_eval(candidate)
            if not isinstance(answer_text, list):
                raise ValueError("Parsed result is not a list.")
        except (ValueError, SyntaxError) as e:
            logger.error(f"Error parsing LLM output: {e}")
            await websocket.send_text(json.dumps({
                "type": "thinking_log",
                "content": f"Warning: AI's file list could not be parsed. Error: {e}"
            }))
            return
        await self.send_command(websocket, "get_file_content", data={"filepaths": answer_text, "request_id": request_id})
    elif decision == "code_change":
        diffs = await self._handle_code_change_response(db=self.db, request_id=request_id, code_diff=code_diff)
        for diff in diffs:
            diff["diff"] = self._format_diff(diff.get("diff", ""))
        payload = json.dumps({
            "type": "chat_message",
            "content": raw_answer_text,
            "reasoning": reasoning,
            # NOTE(review): "dicision" is a typo, but the client may already
            # key on it — fix on both sides together, not here.
            "dicision": decision,
            "code_diff": diffs
        })
        logger.info(f"Sending code change response to client: {payload}")
        await websocket.send_text(payload)
    else:  # decision is "answer": reply to the user directly.
        await websocket.send_text(json.dumps({
            "type": "thinking_log",
            "content": f"Answering user's question directly."
        }))
        await websocket.send_text(json.dumps({
            "type": "chat_message",
            "content": raw_answer_text
        }))
async def handle_command_output(self, websocket: WebSocket, data: Dict[str, Any]):
    """Handle the output of a shell command the client executed for us."""
    command = data.get("command")
    output = data.get("output")
    request_id = data.get("request_id")
    # CONSISTENCY FIX: use the module logger instead of print().
    logger.info(f"Received output for command '{command}' (request_id: {request_id}). Output: {output}")
    # TODO: feed the command output back into the AI to decide the next step.
    await websocket.send_text(json.dumps({
        "type": "thinking_log",
        "content": f"Command '{command}' completed. Analyzing output."
    }))
async def handle_chat_message(self, websocket: WebSocket, data: Dict[str, Any]):
    """
    Handle a user chat message: persist it, and either kick off file
    retrieval (when a workspace path is supplied) or answer via the RAG
    pipeline, persisting and returning the assistant's reply.
    """
    prompt = data.get("content")
    provider_name = data.get("provider_name", "gemini")
    session_id = data.get("session_id")
    if session_id is None:
        await websocket.send_text(json.dumps({
            "type": "error",
            "content": "Error: session_id is required for chat messages."
        }))
        return
    # Load the session together with its message history in one query.
    session = self.db.query(models.Session).options(
        joinedload(models.Session.messages)
    ).filter(models.Session.id == session_id).first()
    if not session:
        await websocket.send_text(json.dumps({
            "type": "error",
            "content": f"Error: Session with ID {session_id} not found."
        }))
        return
    # Record the user's message before doing any further work.
    user_message = models.Message(session_id=session_id, sender="user", content=prompt)
    self.db.add(user_message)
    self.db.commit()
    self.db.refresh(user_message)
    path = data.get("path", "")
    if path:
        # A workspace path means this is a code question: start the file
        # retrieval flow instead of answering directly.
        file_request = await self._get_or_create_file_request(session_id, path, prompt)
        await self.send_command(websocket, "list_directory", data={"request_id": str(file_request.id)})
        return
    # Plain chat: run the RAG pipeline with the session history as context.
    pipeline = DspyRagPipeline()
    with dspy.context(lm=get_llm_provider(provider_name)):
        answer_text = await pipeline(question=prompt, history=session.messages, context_chunks=[])
    # Persist and deliver the assistant's reply; the client-side
    # handleChatMessage handler will render this message.
    assistant_message = models.Message(session_id=session_id, sender="assistant", content=answer_text)
    self.db.add(assistant_message)
    self.db.commit()
    self.db.refresh(assistant_message)
    await websocket.send_text(json.dumps({
        "type": "chat_message",
        "content": answer_text
    }))
    logger.info(f"Sent chat response to client: {answer_text}")