Skip to content

Commit

Permalink
Merge pull request #1599 from kjughx/revert-hunk
Browse files Browse the repository at this point in the history
Add ability to revert hunk
  • Loading branch information
CKolkey authored Dec 17, 2024
2 parents 6691c4e + c769686 commit 9dc5807
Show file tree
Hide file tree
Showing 10 changed files with 54 additions and 48 deletions.
2 changes: 1 addition & 1 deletion lua/neogit/buffers/commit_view/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ function M:open(kind)
end),
[popups.mapping_for("RemotePopup")] = popups.open("remote"),
[popups.mapping_for("RevertPopup")] = popups.open("revert", function(p)
p { commits = { self.commit_info.oid } }
p { commits = { self.commit_info.oid }, item = self.buffer.ui:get_hunk_or_filename_under_cursor() }
end),
[popups.mapping_for("ResetPopup")] = popups.open("reset", function(p)
p { commit = self.commit_info.oid }
Expand Down
25 changes: 15 additions & 10 deletions lua/neogit/buffers/status/actions.lua
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ M.v_discard = function(self)
for _, hunk in ipairs(hunks) do
table.insert(invalidated_diffs, "*:" .. item.name)
table.insert(patches, function()
local patch = git.index.generate_patch(item, hunk, hunk.from, hunk.to, true)
local patch =
git.index.generate_patch(hunk, { from = hunk.from, to = hunk.to, reverse = true })

logger.debug(("Discarding Patch: %s"):format(patch))

Expand Down Expand Up @@ -231,7 +232,7 @@ M.v_stage = function(self)

if #hunks > 0 then
for _, hunk in ipairs(hunks) do
table.insert(patches, git.index.generate_patch(item, hunk, hunk.from, hunk.to))
table.insert(patches, git.index.generate_patch(hunk.hunk, { from = hunk.from, to = hunk.to }))
end
else
if section.name == "unstaged" then
Expand Down Expand Up @@ -281,7 +282,10 @@ M.v_unstage = function(self)

if #hunks > 0 then
for _, hunk in ipairs(hunks) do
table.insert(patches, git.index.generate_patch(item, hunk, hunk.from, hunk.to, true))
table.insert(
patches,
git.index.generate_patch(hunk, { from = hunk.from, to = hunk.to, reverse = true })
)
end
else
table.insert(files, item.escaped_path)
Expand Down Expand Up @@ -781,17 +785,16 @@ M.n_discard = function(self)
local hunk =
self.buffer.ui:item_hunks(selection.item, selection.first_line, selection.last_line, false)[1]

local patch = git.index.generate_patch(selection.item, hunk, hunk.from, hunk.to, true)
local patch = git.index.generate_patch(hunk, { from = hunk.from, to = hunk.to, reverse = true })

if section == "untracked" then
message = "Discard hunk?"
action = function()
local hunks =
self.buffer.ui:item_hunks(selection.item, selection.first_line, selection.last_line, false)

local patch = git.index.generate_patch(selection.item, hunks[1], hunks[1].from, hunks[1].to, true)

git.index.apply(patch, { reverse = true })
local patch =
git.index.generate_patch(hunks[1], { from = hunks[1].from, to = hunks[1].to, reverse = true })
git.index.apply(patch, { reverse = true })
end
refresh = { update_diffs = { "untracked:" .. selection.item.name } }
Expand Down Expand Up @@ -1057,7 +1060,7 @@ M.n_stage = function(self)
local item = self.buffer.ui:get_item_under_cursor()
assert(item, "Item cannot be nil")

local patch = git.index.generate_patch(item, stagable.hunk, stagable.hunk.from, stagable.hunk.to)
local patch = git.index.generate_patch(stagable.hunk)
git.index.apply(patch, { cached = true })
self:dispatch_refresh({ update_diffs = { "*:" .. item.escaped_path } }, "n_stage")
elseif stagable.filename then
Expand Down Expand Up @@ -1131,8 +1134,10 @@ M.n_unstage = function(self)
if unstagable.hunk then
local item = self.buffer.ui:get_item_under_cursor()
assert(item, "Item cannot be nil")
local patch =
git.index.generate_patch(item, unstagable.hunk, unstagable.hunk.from, unstagable.hunk.to, true)
local patch = git.index.generate_patch(
unstagable.hunk,
{ from = unstagable.hunk.from, to = unstagable.hunk.to, reverse = true }
)

git.index.apply(patch, { cached = true, reverse = true })
self:dispatch_refresh({ update_diffs = { "*:" .. item.escaped_path } }, "n_unstage")
Expand Down
7 changes: 7 additions & 0 deletions lua/neogit/lib/git/diff.lua
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@ local sha256 = vim.fn.sha256
---@field deletions number
---
---@class Hunk
---@field file string
---@field index_from number
---@field index_len number
---@field diff_from number
---@field diff_to number
---@field first number First line number in buffer
---@field last number Last line number in buffer
---@field lines string[]
---
---@class DiffStagedStats
---@field summary string
Expand Down Expand Up @@ -224,6 +226,11 @@ local function parse_diff(raw_diff, raw_stats)
local file = build_file(header, kind)
local stats = parse_diff_stats(raw_stats or {})

util.map(hunks, function(hunk)
hunk.file = file
return hunk
end)

return { ---@type Diff
kind = kind,
lines = lines,
Expand Down
36 changes: 13 additions & 23 deletions lua/neogit/lib/git/index.lua
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,15 @@ local util = require("neogit.lib.util")
local M = {}

---Generates a patch that can be applied to index
---@param item any
---@param hunk Hunk
---@param from number
---@param to number
---@param reverse boolean|nil
---@param opts table|nil
---@return string
function M.generate_patch(item, hunk, from, to, reverse)
reverse = reverse or false
function M.generate_patch(hunk, opts)
opts = opts or { reverse = false, cached = false, index = false }
local reverse = opts.reverse

if not from and not to then
from = hunk.diff_from + 1
to = hunk.diff_to
end
local from = opts.from or 1
local to = opts.to or (hunk.diff_to - hunk.diff_from)

assert(from <= to, string.format("from must be less than or equal to to %d %d", from, to))
if from > to then
Expand All @@ -29,35 +25,31 @@ function M.generate_patch(item, hunk, from, to, reverse)
local len_start = hunk.index_len
local len_offset = 0

-- + 1 skips the hunk header, since we construct that manually afterwards
-- TODO: could use `hunk.lines` instead if this is only called with the `SelectedHunk` type
for k = hunk.diff_from + 1, hunk.diff_to do
local v = item.diff.lines[k]
local operand, line = v:match("^([+ -])(.*)")

for k, line in pairs(hunk.lines) do
local operand, l = line:match("^([+ -])(.*)")
if operand == "+" or operand == "-" then
if from <= k and k <= to then
len_offset = len_offset + (operand == "+" and 1 or -1)
table.insert(diff_content, v)
table.insert(diff_content, line)
else
-- If we want to apply the patch normally, we need to include every `-` line we skip as a normal line,
-- since we want to keep that line.
if not reverse then
if operand == "-" then
table.insert(diff_content, " " .. line)
table.insert(diff_content, " " .. l)
end
-- If we want to apply the patch in reverse, we need to include every `+` line we skip as a normal line, since
-- it's unchanged as far as the diff is concerned and should not be reversed.
-- We also need to adapt the original line offset based on if we skip or not
elseif reverse then
if operand == "+" then
table.insert(diff_content, " " .. line)
table.insert(diff_content, " " .. l)
end
len_start = len_start + (operand == "-" and -1 or 1)
end
end
else
table.insert(diff_content, v)
table.insert(diff_content, line)
end
end

Expand All @@ -68,9 +60,7 @@ function M.generate_patch(item, hunk, from, to, reverse)
)

local worktree_root = git.repo.worktree_root

assert(item.absolute_path, "Item is not a path")
local path = Path:new(item.absolute_path):make_relative(worktree_root)
local path = Path:new(hunk.file):make_relative(worktree_root)

table.insert(diff_content, 1, string.format("+++ b/%s", path))
table.insert(diff_content, 1, string.format("--- a/%s", path))
Expand Down
5 changes: 5 additions & 0 deletions lua/neogit/lib/git/revert.lua
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ function M.commits(commits, args)
end
end

function M.hunk(hunk, _)
local patch = git.index.generate_patch(hunk, { reverse = true })
git.index.apply(patch, { reverse = true })
end

function M.continue()
git.cli.revert.continue.call()
end
Expand Down
11 changes: 2 additions & 9 deletions lua/neogit/lib/ui/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -182,25 +182,19 @@ function Ui:item_hunks(item, first_line, last_line, partial)

if not item.folded and item.diff.hunks then
for _, h in ipairs(item.diff.hunks) do
if h.first <= last_line and h.last >= first_line then
if h.first <= first_line and h.last >= last_line then
local from, to

if partial then
local cursor_offset = first_line - h.first
local length = last_line - first_line

from = h.diff_from + cursor_offset
from = first_line - h.first
to = from + length
else
from = h.diff_from + 1
to = h.diff_to
end

local hunk_lines = {}
for i = from, to do
table.insert(hunk_lines, item.diff.lines[i])
end

-- local conflict = false
-- for _, n in ipairs(conflict_markers) do
-- if from <= n and n <= to then
Expand All @@ -214,7 +208,6 @@ function Ui:item_hunks(item, first_line, last_line, partial)
to = to,
__index = h,
hunk = h,
lines = hunk_lines,
-- conflict = conflict,
}

Expand Down
4 changes: 4 additions & 0 deletions lua/neogit/popups/revert/actions.lua
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ function M.changes(popup)
end
end

function M.hunk(popup)
git.revert.hunk(popup.state.env.item.hunk, popup:get_arguments())
end

function M.continue()
git.revert.continue()
end
Expand Down
1 change: 1 addition & 0 deletions lua/neogit/popups/revert/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ function M.create(env)
:group_heading("Revert")
:action_if(not in_progress, "v", "Commit(s)", actions.commits)
:action_if(not in_progress, "V", "Changes", actions.changes)
:action_if(((not in_progress) and env.item ~= nil), "h", "Hunk", actions.hunk)
:action_if(in_progress, "v", "continue", actions.continue)
:action_if(in_progress, "s", "skip", actions.skip)
:action_if(in_progress, "a", "abort", actions.abort)
Expand Down
8 changes: 3 additions & 5 deletions tests/specs/neogit/lib/git/index_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,15 @@ local function run_with_hunk(hunk, from, to, reverse)
local header_matches =
vim.fn.matchlist(lines[1], "@@ -\\(\\d\\+\\),\\(\\d\\+\\) +\\(\\d\\+\\),\\(\\d\\+\\) @@")
return generate_patch_from_selection({
name = "test.txt",
absolute_path = "test.txt",
diff = { lines = lines },
}, {
first = 1,
last = #lines,
index_from = header_matches[2],
index_len = header_matches[3],
diff_from = diff_from,
diff_to = #lines,
}, diff_from + from, diff_from + to, reverse)
lines = vim.list_slice(lines, 2),
file = "test.txt",
}, { from = from, to = to, reverse = reverse })
end

describe("patch creation", function()
Expand Down
3 changes: 3 additions & 0 deletions tests/specs/neogit/lib/git/log_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ describe("lib.git.log.parse", function()
index_from = 692,
index_len = 33,
length = 40,
file = "lua/neogit/status.lua",
line = "@@ -692,33 +692,28 @@ end",
lines = {
" ---@param first_line number",
Expand Down Expand Up @@ -149,6 +150,7 @@ describe("lib.git.log.parse", function()
index_from = 734,
index_len = 14,
length = 15,
file = "lua/neogit/status.lua",
line = "@@ -734,14 +729,10 @@ function M.get_item_hunks(item, first_line, last_line, partial)",
lines = {
" setmetatable(o, o)",
Expand Down Expand Up @@ -290,6 +292,7 @@ describe("lib.git.log.parse", function()
index_len = 7,
length = 9,
line = "@@ -1,7 +1,9 @@",
file = "LICENSE",
lines = {
" MIT License",
" ",
Expand Down

0 comments on commit 9dc5807

Please sign in to comment.