diff --git a/lua/r/send.lua b/lua/r/send.lua index 50bfa1f1..2216b1e1 100644 --- a/lua/r/send.lua +++ b/lua/r/send.lua @@ -645,97 +645,95 @@ M.line = function(m) end end --- Function to check if a string ends with a specific suffix ----@param str string ----@param suffix string ----@return boolean -local function ends_with(str, suffix) return str:sub(-#suffix) == suffix end - -local function trim_lines(array) - local result = {} -- Create a new table to store the trimmed lines - - for i = 1, #array do - local line = array[i] - local trimmedLine = line:match("^%s*(.-)%s*$") -- Remove leading and trailing whitespace - table.insert(result, trimmedLine) -- Add the trimmed line to the result table - end - - return result -end - --- Remove the <-, |>/%>% or + from the text ----@param array string[] ----@return string[] -local function sanatize_text(array) - local firstString = array[1] - -- Remove "<-" and everything before it from the first string - local modifiedFirstString = firstString:gsub(".*<%-%s*", "") - array[1] = modifiedFirstString +--- Send the above chain of piped commands +M.chain = function() + local bufnr = create_r_buffer() + if not bufnr then return end - local lastIndex = #array - local lastString = array[lastIndex] + local parser = vim.treesitter.get_parser(bufnr, "r") + if not parser then return end - -- Check if the last string ends with either "|>" or "%>%" - local modifiedString = - lastString:gsub("|>[%s]*$", ""):gsub("%%>%%[%s]*$", ""):gsub("%+[%s]*$", "") - array[lastIndex] = modifiedString + local tree = parser:parse()[1] + if not tree then return end - return array -end + local root = tree:root() + local query = vim.treesitter.query.parse( + "r", + [[ + (_ + (binary_operator + lhs: (_) + operator: ([("|>") ("<-") ("+") ("special")]) + rhs: (call) + ) @pipeline_no_assign + (#not-has-parent? @pipeline_no_assign binary_operator) + ) + + (_ + ; Handle when the pipeline is assignment to a variable + (binary_operator + lhs: (identifier) + rhs: (binary_operator + lhs: (_) + operator: ([("|>") ("+") ("special")]) + rhs: (call) + ) @pipeline_with_assign + ) + ) + ]] + ) ---- Check if string ends in one of specific pre-defined patterns ----@param str string ----@return boolean -function ends_with(str) - return string.match(str, "[|%%]%>%%?[%s]*$") ~= nil - or string.match(str, "%+[%s]*$") ~= nil - or string.match(str, "%([%s]*$") ~= nil -end + local cursor_row = vim.api.nvim_win_get_cursor(0)[1] - 1 + local pipe_block_node ---- Return the line where piped chain begins ----@param arr string[] ----@return number -local function chain_start_at(arr) - for i = 1, #arr do - if ends_with(arr[i]) then return i end + for _, node in query:iter_captures(root, bufnr, 0, -1) do + local start_row, _, end_row = node:range() + if cursor_row >= start_row and cursor_row <= end_row then + pipe_block_node = node + break + end end - return #arr -end - ---- Send the above chain of piped commands -M.chain = function() - -- Get the current line, the start and end line of the paragraph - local current_line = vim.api.nvim_win_get_cursor(0)[1] - local startLine = vim.fn.search("^$", "bnW") -- Search for previous empty line - local endLine = vim.fn.search("^$", "nW") - 1 -- Search for next empty line and adjust for exclusive range - - -- Get the paragraph lines - local paragraphLines = vim.api.nvim_buf_get_lines(0, startLine, endLine, false) - paragraphLines = trim_lines(paragraphLines) - - -- Get the relative line number within the paragraph - local relativeLineNumber = current_line - startLine - - paragraphLines = trim_lines(paragraphLines) - - local extractedLines = {} - for i = 1, relativeLineNumber do - table.insert(extractedLines, paragraphLines[i]) + if not pipe_block_node then + inform("The cursor is not inside a piped expression.") + return end - -- Find the starting line of the chain - local lineChainStartAt = chain_start_at(extractedLines) + local call_query = vim.treesitter.query.parse( + "r", + [[ + (_ + (binary_operator + lhs: (_) + operator: (["|>" "+" "special"] @operator) + rhs: (call) @call + (#not-has-ancestor? @call call) ;; Ensure the rhs is not inside another call + ) + ) + ]] + ) - local chain = {} + local sibling = nil + local visited = false - for i = lineChainStartAt, relativeLineNumber do - table.insert(chain, extractedLines[i]) + for id, node, _ in call_query:iter_captures(pipe_block_node, bufnr, 0, -1) do + local capture_name = call_query.captures[id] + local start_row, _, end_row = node:range() + + if + capture_name == "operator" and visited + or cursor_row == pipe_block_node:range() + then + sibling = node:prev_sibling() + break + elseif capture_name == "call" then + if cursor_row >= start_row and cursor_row <= end_row then visited = true end + end end - chain = sanatize_text(chain) + local captured_node = sibling or pipe_block_node - M.source_lines(chain, nil) + M.source_lines({ vim.treesitter.get_node_text(captured_node, bufnr) }, nil) end --- Retrieves R function nodes from a given buffer using TreeSitter.