Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions gptdiff/applydiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,21 @@
import hashlib
from collections import defaultdict


def _strip_diff_fence(diff_text: str) -> str:
"""Remove wrapping triple backticks from a diff block."""
stripped = diff_text.strip()
if stripped.startswith("```") and stripped.endswith("```"):
lines = stripped.splitlines()
if lines[0].startswith("```"):
lines = lines[1:]
if lines and lines[-1].startswith("```"):
lines = lines[:-1]
if lines and lines[0].strip().lower() == "diff":
lines = lines[1:]
return "\n".join(lines)
return diff_text

def apply_diff(project_dir, diff_text):
"""
Applies a unified diff (as generated by git diff) to the files in project_dir
Expand Down Expand Up @@ -112,6 +127,7 @@ def apply_patch_to_file(file_path, patch):
return True

# Parse the diff into per-file patches.
diff_text = _strip_diff_fence(diff_text)
file_patches = parse_diff_per_file(diff_text)
if not file_patches:
print("No file patches found in diff.")
Expand Down Expand Up @@ -192,6 +208,8 @@ def parse_diff_per_file(diff_text):
Uses 'b/' prefix detection from git diffs to determine target paths
This doesn't work all the time and needs to be revised with stronger models
"""
diff_text = _strip_diff_fence(diff_text)

def dedup_diffs(diffs):
groups = defaultdict(list)
for key, value in diffs:
Expand Down
8 changes: 7 additions & 1 deletion tests/test_multidiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,15 @@ def test_fail_diff_through_call_llm(monkeypatch):
def dummy_call_llm(api_key, base_url, model, messages, max_tokens, budget_tokens, temperature):
return DummyResponse(diff_str, prompt_tokens=10, completion_tokens=20, total_tokens=30)

# Patch call_llm in the gptdiff module with our dummy function.
# Patch call_llm and tiktoken.get_encoding to avoid network access.
monkeypatch.setattr("gptdiff.gptdiff.call_llm", dummy_call_llm)

class DummyEnc:
    """Stand-in encoder whose ``encode`` yields one zero per input item."""

    def encode(self, text):
        """Return a list of zeros with the same length as *text*."""
        return [0 for _ in range(len(text))]

monkeypatch.setattr("tiktoken.get_encoding", lambda name: DummyEnc())

# generate_diff calls call_llm_for_diff internally, which now uses our dummy_call_llm.
result = generate_diff("dummy environment", "dummy goal", model="test-model")

Expand Down
17 changes: 17 additions & 0 deletions tests/test_parse_diff_per_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,5 +169,22 @@ def test_parse_diff_per_file_unconventional_header():
assert "+++ game.js" in patch, "Expected patch to include '+++ game.js'"
assert "+let player" in patch, "Expected patch to include added lines"


def test_parse_diff_with_code_fence():
    """A diff wrapped in a ```diff fence parses to one patch for file.txt."""
    fenced_diff = """```diff
diff --git a/file.txt b/file.txt
--- a/file.txt
+++ b/file.txt
@@ -1 +1 @@
-old
+new
```"""

    patches = parse_diff_per_file(fenced_diff)

    # Exactly one file is touched by the fenced diff above.
    assert len(patches) == 1

    target_path = patches[0][0]
    patch_body = patches[0][1]
    assert target_path == "file.txt"
    assert "+new" in patch_body

# Run the tests through unittest when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()