Skip to content

Commit

Permalink
refactor(send.lua): improve piped command handling with TreeSitter
Browse files Browse the repository at this point in the history
Replace manual string manipulation and pattern matching with
TreeSitter queries to identify and process piped command chains.
This enhances the accuracy and maintainability of the code by
leveraging TreeSitter's parsing capabilities, reducing reliance
on custom string functions.
  • Loading branch information
PMassicotte committed Feb 23, 2025
1 parent 3f98d95 commit ad23ef9
Showing 1 changed file with 75 additions and 77 deletions.
152 changes: 75 additions & 77 deletions lua/r/send.lua
Original file line number Diff line number Diff line change
Expand Up @@ -645,97 +645,95 @@ M.line = function(m)
end
end

-- Function to check if a string ends with a specific suffix
---@param str string
---@param suffix string
---@return boolean
local function ends_with(str, suffix) return str:sub(-#suffix) == suffix end

local function trim_lines(array)
local result = {} -- Create a new table to store the trimmed lines

for i = 1, #array do
local line = array[i]
local trimmedLine = line:match("^%s*(.-)%s*$") -- Remove leading and trailing whitespace
table.insert(result, trimmedLine) -- Add the trimmed line to the result table
end

return result
end

-- Remove the <-, |>/%>% or + from the text
---@param array string[]
---@return string[]
local function sanatize_text(array)
local firstString = array[1]
-- Remove "<-" and everything before it from the first string
local modifiedFirstString = firstString:gsub(".*<%-%s*", "")
array[1] = modifiedFirstString
--- Send the above chain of piped commands
M.chain = function()
local bufnr = create_r_buffer()
if not bufnr then return end

local lastIndex = #array
local lastString = array[lastIndex]
local parser = vim.treesitter.get_parser(bufnr, "r")
if not parser then return end

-- Check if the last string ends with either "|>" or "%>%"
local modifiedString =
lastString:gsub("|>[%s]*$", ""):gsub("%%>%%[%s]*$", ""):gsub("%+[%s]*$", "")
array[lastIndex] = modifiedString
local tree = parser:parse()[1]
if not tree then return end

return array
end
local root = tree:root()
local query = vim.treesitter.query.parse(
"r",
[[
(_
(binary_operator
lhs: (_)
operator: ([("|>") ("<-") ("+") ("special")])
rhs: (call)
) @pipeline_no_assign
(#not-has-parent? @pipeline_no_assign binary_operator)
)
(_
; Handle when the pipeline is assignment to a variable
(binary_operator
lhs: (identifier)
rhs: (binary_operator
lhs: (_)
operator: ([("|>") ("+") ("special")])
rhs: (call)
) @pipeline_with_assign
)
)
]]
)

--- Check if string ends in one of specific pre-defined patterns
---@param str string
---@return boolean
function ends_with(str)
return string.match(str, "[|%%]%>%%?[%s]*$") ~= nil
or string.match(str, "%+[%s]*$") ~= nil
or string.match(str, "%([%s]*$") ~= nil
end
local cursor_row = vim.api.nvim_win_get_cursor(0)[1] - 1
local pipe_block_node

--- Return the line where piped chain begins
---@param arr string[]
---@return number
local function chain_start_at(arr)
for i = 1, #arr do
if ends_with(arr[i]) then return i end
for _, node in query:iter_captures(root, bufnr, 0, -1) do
local start_row, _, end_row = node:range()
if cursor_row >= start_row and cursor_row <= end_row then
pipe_block_node = node
break
end
end

return #arr
end

--- Send the above chain of piped commands
M.chain = function()
-- Get the current line, the start and end line of the paragraph
local current_line = vim.api.nvim_win_get_cursor(0)[1]
local startLine = vim.fn.search("^$", "bnW") -- Search for previous empty line
local endLine = vim.fn.search("^$", "nW") - 1 -- Search for next empty line and adjust for exclusive range

-- Get the paragraph lines
local paragraphLines = vim.api.nvim_buf_get_lines(0, startLine, endLine, false)
paragraphLines = trim_lines(paragraphLines)

-- Get the relative line number within the paragraph
local relativeLineNumber = current_line - startLine

paragraphLines = trim_lines(paragraphLines)

local extractedLines = {}
for i = 1, relativeLineNumber do
table.insert(extractedLines, paragraphLines[i])
if not pipe_block_node then
inform("The cursor is not inside a piped expression.")
return
end

-- Find the starting line of the chain
local lineChainStartAt = chain_start_at(extractedLines)
local call_query = vim.treesitter.query.parse(
"r",
[[
(_
(binary_operator
lhs: (_)
operator: (["|>" "+" "special"] @operator)
rhs: (call) @call
(#not-has-ancestor? @call call) ;; Ensure the rhs is not inside another call
)
)
]]
)

local chain = {}
local sibling = nil
local visited = false

for i = lineChainStartAt, relativeLineNumber do
table.insert(chain, extractedLines[i])
for id, node, _ in call_query:iter_captures(pipe_block_node, bufnr, 0, -1) do
local capture_name = call_query.captures[id]
local start_row, _, end_row = node:range()

if
capture_name == "operator" and visited
or cursor_row == pipe_block_node:range()
then
sibling = node:prev_sibling()
break
elseif capture_name == "call" then
if cursor_row >= start_row and cursor_row <= end_row then visited = true end
end
end

chain = sanatize_text(chain)
local captured_node = sibling or pipe_block_node

M.source_lines(chain, nil)
M.source_lines({ vim.treesitter.get_node_text(captured_node, bufnr) }, nil)
end

--- Retrieves R function nodes from a given buffer using TreeSitter.
Expand Down

0 comments on commit ad23ef9

Please sign in to comment.