diff --git a/ai-hub/app/core/services/workspace.py b/ai-hub/app/core/services/workspace.py index 2a2cf23..fdfc4f3 100644 --- a/ai-hub/app/core/services/workspace.py +++ b/ai-hub/app/core/services/workspace.py @@ -271,14 +271,19 @@ logger.warning(f"File with path {file_path} not found for request ID {request_id} or has no content.") return "" - async def _handle_code_change_response(self, db: Session ,request_id: str, code_diff: str) -> List[Dict[str, Any]]: + async def _handle_code_change_response(self, db: Session, request_id: str, code_diff: str) -> List[Dict[str, Any]]: """ Parses the diff, retrieves original file content, and returns a structured, per-file dictionary for the client. """ + # Normalize the diff string to ensure consistent splitting, handling cases where + # the separator may be missing a leading newline. + normalized_diff = re.sub(r'(? str: """ Applies a unified diff to the original content and returns the new content. @@ -366,57 +370,68 @@ """ # Handle the case where the original content is empty. if not original_content: - # If the original content is empty, just add the new lines from the diff. new_content: List[str] = [] for line in file_diff.splitlines(keepends=True): - if line.startswith('+'): + if line.startswith('+') and not line.startswith('+++'): new_content.append(line[1:]) return ''.join(new_content) original_lines = original_content.splitlines(keepends=True) diff_lines = file_diff.splitlines(keepends=True) - # Skip diff headers like --- / +++ i = 0 - while i < len(diff_lines) and not diff_lines[i].startswith('@@'): - i += 1 - - if i == len(diff_lines): - return original_content # No hunks to apply - new_content: List[str] = [] - orig_idx = 0 # Pointer in original_lines + orig_idx = 0 while i < len(diff_lines): + # Skip diff headers like --- and +++ + if diff_lines[i].startswith('---') or diff_lines[i].startswith('+++'): + i += 1 + continue + + # Hunk header + if not diff_lines[i].startswith('@@'): + i += 1 + continue + hunk_header = diff_lines[i] m = re.match(r'^@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@', hunk_header) if not m: raise ValueError(f"Invalid hunk header: {hunk_header.strip()}") - - orig_start = int(m.group(1)) - 1 # line numbers in diff are 1-based + + orig_start = int(m.group(1)) - 1 # convert from 1-based to 0-based index i += 1 - # Add unchanged lines before the hunk + # Copy unchanged lines before this hunk while orig_idx < orig_start: new_content.append(original_lines[orig_idx]) orig_idx += 1 - # Process hunk lines - while i < len(diff_lines) and not diff_lines[i].startswith('@@'): + # Process lines in hunk + while i < len(diff_lines): line = diff_lines[i] - if line.startswith(' '): - new_content.append(original_lines[orig_idx]) - orig_idx += 1 + + if line.startswith('@@'): + # Start of next hunk + break + elif line.startswith(' '): + # Context line + if orig_idx < len(original_lines): + new_content.append(original_lines[orig_idx]) + orig_idx += 1 elif line.startswith('-'): + # Removed line orig_idx += 1 elif line.startswith('+'): - new_content.append(line[1:]) # Add the new line without '+' + # Added line + new_content.append(line[1:]) i += 1 - # Add the remaining lines from the original + # Add remaining lines from original new_content.extend(original_lines[orig_idx:]) return ''.join(new_content) + async def send_command(self, websocket: WebSocket, command_name: str, data: Dict[str, Any] = {}): if command_name not in self.command_map: