Skip to content

Commit f245cca

Browse files
authored
fix(injected): handle inline injections (#251)
1 parent 7396fc0 commit f245cca

19 files changed

+357
-55
lines changed

.github/workflows/automation_remove_question_label_on_comment.yml

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ jobs:
88
# issues in my "needs triage" filter.
99
remove_question:
1010
runs-on: ubuntu-latest
11+
if: github.event.sender.login != 'stevearc'
1112
steps:
1213
- uses: actions/checkout@v2
1314
- uses: actions-ecosystem/action-remove-labels@v1

lua/conform/formatters/injected.lua

+96-22
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,31 @@ local function apply_indent(lines, indentation)
6060
end
6161
end
6262

63+
---@class LangRange
64+
---@field [1] string language
65+
---@field [2] integer start lnum
66+
---@field [3] integer start col
67+
---@field [4] integer end lnum
68+
---@field [5] integer end col
69+
70+
---@param ranges LangRange[]
71+
---@param range LangRange
72+
local function accum_range(ranges, range)
73+
local last_range = ranges[#ranges]
74+
if last_range then
75+
if last_range[1] == range[1] and last_range[4] == range[2] and last_range[5] == range[3] then
76+
last_range[4] = range[4]
77+
last_range[5] = range[5]
78+
return
79+
end
80+
end
81+
table.insert(ranges, range)
82+
end
83+
6384
---@class (exact) conform.InjectedFormatterOptions
6485
---@field ignore_errors boolean
86+
---@field lang_to_ext table<string, string>
87+
---@field lang_to_formatters table<string, conform.FiletypeFormatter>
6588

6689
---@type conform.FileLuaFormatterConfig
6790
return {
@@ -72,6 +95,26 @@ return {
7295
options = {
7396
-- Set to true to ignore errors
7497
ignore_errors = false,
98+
-- Map of treesitter language to file extension
99+
-- A temporary file name with this extension will be generated during formatting
100+
-- because some formatters care about the filename.
101+
lang_to_ext = {
102+
bash = "sh",
103+
c_sharp = "cs",
104+
elixir = "exs",
105+
javascript = "js",
106+
julia = "jl",
107+
latex = "tex",
108+
markdown = "md",
109+
python = "py",
110+
ruby = "rb",
111+
rust = "rs",
112+
teal = "tl",
113+
typescript = "ts",
114+
},
115+
-- Map of treesitter language to formatters to use
116+
-- (defaults to the value from formatters_by_ft)
117+
lang_to_formatters = {},
75118
},
76119
condition = function(self, ctx)
77120
local ok, parser = pcall(vim.treesitter.get_parser, ctx.buf)
@@ -93,12 +136,20 @@ return {
93136
end
94137
---@type conform.InjectedFormatterOptions
95138
local options = self.options
139+
140+
---@param lang string
141+
---@return nil|conform.FiletypeFormatter
142+
local function get_formatters(lang)
143+
return options.lang_to_formatters[lang] or conform.formatters_by_ft[lang]
144+
end
145+
96146
--- Disable diagnostic to pass the typecheck github action
97147
--- This is available on nightly, but not on stable
98148
--- Stable doesn't have any parameters, so it's safe to always pass `false`
99149
---@diagnostic disable-next-line: redundant-parameter
100150
parser:parse(false)
101151
local root_lang = parser:lang()
152+
---@type LangRange[]
102153
local regions = {}
103154

104155
for _, tree in pairs(parser:trees()) do
@@ -124,26 +175,26 @@ return {
124175
do
125176
---@diagnostic disable-next-line: invisible
126177
local lang, combined, ranges = parser:_get_injection(match, metadata)
127-
local has_formatters = conform.formatters_by_ft[lang] ~= nil
128-
if lang and has_formatters and not combined and #ranges > 0 and lang ~= root_lang then
129-
local start_lnum
130-
local end_lnum
131-
-- Merge all of the ranges into a single range
178+
if
179+
lang
180+
and get_formatters(lang) ~= nil
181+
and not combined
182+
and #ranges > 0
183+
and lang ~= root_lang
184+
then
132185
for _, range in ipairs(ranges) do
133-
if not start_lnum or start_lnum > range[1] + 1 then
134-
start_lnum = range[1] + 1
135-
end
136-
if not end_lnum or end_lnum < range[4] then
137-
end_lnum = range[4]
138-
end
139-
end
140-
if in_range(ctx.range, start_lnum, end_lnum) then
141-
table.insert(regions, { lang, start_lnum, end_lnum })
186+
accum_range(regions, { lang, range[1] + 1, range[2], range[4] + 1, range[5] })
142187
end
143188
end
144189
end
145190
end
146191

192+
if ctx.range then
193+
regions = vim.tbl_filter(function(region)
194+
return in_range(ctx.range, region[2], region[4])
195+
end, regions)
196+
end
197+
147198
-- Sort from largest start_lnum to smallest
148199
table.sort(regions, function(a, b)
149200
return a[2] > b[2]
@@ -171,7 +222,11 @@ return {
171222

172223
local formatted_lines = vim.deepcopy(lines)
173224
for _, replacement in ipairs(replacements) do
174-
local start_lnum, end_lnum, new_lines = unpack(replacement)
225+
local start_lnum, start_col, end_lnum, end_col, new_lines = unpack(replacement)
226+
local prefix = formatted_lines[start_lnum]:sub(1, start_col)
227+
local suffix = formatted_lines[end_lnum]:sub(end_col + 1)
228+
new_lines[1] = prefix .. new_lines[1]
229+
new_lines[#new_lines] = new_lines[#new_lines] .. suffix
175230
for _ = start_lnum, end_lnum do
176231
table.remove(formatted_lines, start_lnum)
177232
end
@@ -184,12 +239,20 @@ return {
184239

185240
local num_format = 0
186241
local tmp_bufs = {}
187-
local formatter_cb = function(err, idx, start_lnum, end_lnum, new_lines)
242+
local formatter_cb = function(err, idx, region, input_lines, new_lines)
188243
if err then
189244
format_error = errors.coalesce(format_error, err)
190245
replacements[idx] = err
191246
else
192-
replacements[idx] = { start_lnum, end_lnum, new_lines }
247+
-- If the original lines started/ended with a newline, preserve that newline.
248+
-- Many formatters will trim them, but they're important for the document structure.
249+
if input_lines[1] == "" and new_lines[1] ~= "" then
250+
table.insert(new_lines, 1, "")
251+
end
252+
if input_lines[#input_lines] == "" and new_lines[#new_lines] ~= "" then
253+
table.insert(new_lines, "")
254+
end
255+
replacements[idx] = { region[2], region[3], region[4], region[5], new_lines }
193256
end
194257
num_format = num_format - 1
195258
if num_format == 0 then
@@ -200,14 +263,22 @@ return {
200263
end
201264
end
202265
local last_start_lnum = #lines + 1
203-
for _, region in ipairs(regions) do
204-
local lang, start_lnum, end_lnum = unpack(region)
266+
for i, region in ipairs(regions) do
267+
local lang = region[1]
268+
local start_lnum = region[2]
269+
local start_col = region[3]
270+
local end_lnum = region[4]
271+
local end_col = region[5]
205272
-- Ignore regions that overlap (contain) other regions
206273
if end_lnum < last_start_lnum then
207274
num_format = num_format + 1
208275
last_start_lnum = start_lnum
209276
local input_lines = util.tbl_slice(lines, start_lnum, end_lnum)
210-
local ft_formatters = conform.formatters_by_ft[lang]
277+
input_lines[#input_lines] = input_lines[#input_lines]:sub(1, end_col)
278+
if start_col > 0 then
279+
input_lines[1] = input_lines[1]:sub(start_col + 1)
280+
end
281+
local ft_formatters = assert(get_formatters(lang))
211282
---@type string[]
212283
local formatter_names
213284
if type(ft_formatters) == "function" then
@@ -226,15 +297,18 @@ return {
226297
-- extension to determine a run mode (see https://github.com/stevearc/conform.nvim/issues/194)
227298
-- This is using the language name as the file extension, but that is a reasonable
228299
-- approximation for now. We can add special cases as the need arises.
229-
local buf = vim.fn.bufadd(string.format("%s.%s", vim.api.nvim_buf_get_name(ctx.buf), lang))
300+
local extension = options.lang_to_ext[lang] or lang
301+
local buf =
302+
vim.fn.bufadd(string.format("%s.%d.%s", vim.api.nvim_buf_get_name(ctx.buf), i, extension))
230303
-- Actually load the buffer to set the buffer context which is required by some formatters such as `filetype`
231304
vim.fn.bufload(buf)
232305
tmp_bufs[buf] = true
233306
local format_opts = { async = true, bufnr = buf, quiet = true }
234307
conform.format_lines(formatter_names, input_lines, format_opts, function(err, new_lines)
308+
log.trace("Injected %s:%d:%d formatted lines %s", lang, start_lnum, end_lnum, new_lines)
235309
-- Preserve indentation in case the code block is indented
236310
apply_indent(new_lines, indent)
237-
formatter_cb(err, idx, start_lnum, end_lnum, new_lines)
311+
vim.schedule_wrap(formatter_cb)(err, idx, region, input_lines, new_lines)
238312
end)
239313
end
240314
end

lua/conform/fs.lua

+20
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,24 @@ M.join = function(...)
1515
return table.concat({ ... }, M.sep)
1616
end
1717

18+
---@param filepath string
19+
---@return boolean
20+
M.exists = function(filepath)
21+
local stat = uv.fs_stat(filepath)
22+
return stat ~= nil and stat.type ~= nil
23+
end
24+
25+
---@param filepath string
26+
---@return string?
27+
M.read_file = function(filepath)
28+
if not M.exists(filepath) then
29+
return nil
30+
end
31+
local fd = assert(uv.fs_open(filepath, "r", 420)) -- 0644
32+
local stat = assert(uv.fs_fstat(fd))
33+
local content = uv.fs_read(fd, stat.size)
34+
uv.fs_close(fd)
35+
return content
36+
end
37+
1838
return M

lua/conform/init.lua

+8-1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ local M = {}
3636
---@field inherit? boolean
3737
---@field command? string|fun(self: conform.FormatterConfig, ctx: conform.Context): string
3838
---@field prepend_args? string|string[]|fun(self: conform.FormatterConfig, ctx: conform.Context): string|string[]
39+
---@field format? fun(self: conform.LuaFormatterConfig, ctx: conform.Context, lines: string[], callback: fun(err: nil|string, new_lines: nil|string[])) Mutually exclusive with command
3940
---@field options? table
4041

4142
---@class (exact) conform.FormatterMeta
@@ -569,6 +570,12 @@ M.get_formatter_config = function(formatter, bufnr)
569570
if type(override) == "function" then
570571
override = override(bufnr)
571572
end
573+
if override and override.command and override.format then
574+
local msg =
575+
string.format("Formatter '%s' cannot define both 'command' and 'format' function", formatter)
576+
vim.notify_once(msg, vim.log.levels.ERROR)
577+
return nil
578+
end
572579

573580
---@type nil|conform.FormatterConfig
574581
local config = override
@@ -581,7 +588,7 @@ M.get_formatter_config = function(formatter, bufnr)
581588
config = mod_config
582589
end
583590
elseif override then
584-
if override.command then
591+
if override.command or override.format then
585592
config = override
586593
else
587594
local msg = string.format(

lua/conform/runner.lua

+1
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,7 @@ local function run_formatter(bufnr, formatter, config, ctx, input_lines, opts, c
342342
end
343343
log.debug("%s exited with code %d", formatter.name, code)
344344
log.trace("Output lines: %s", output)
345+
log.trace("%s stderr: %s", formatter.name, stderr)
345346
callback(nil, output)
346347
else
347348
log.info("%s exited with code %d", formatter.name, code)

run_tests.sh

+7-1
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,17 @@ else
1414
(cd "$PLUGINS/plenary.nvim" && git pull)
1515
fi
1616

17+
if [ ! -e "$PLUGINS/nvim-treesitter" ]; then
18+
git clone --depth=1 https://github.com/nvim-treesitter/nvim-treesitter.git "$PLUGINS/nvim-treesitter"
19+
else
20+
(cd "$PLUGINS/nvim-treesitter" && git pull)
21+
fi
22+
1723
XDG_CONFIG_HOME=".testenv/config" \
1824
XDG_DATA_HOME=".testenv/data" \
1925
XDG_STATE_HOME=".testenv/state" \
2026
XDG_RUNTIME_DIR=".testenv/run" \
2127
XDG_CACHE_HOME=".testenv/cache" \
2228
nvim --headless -u tests/minimal_init.lua \
23-
-c "PlenaryBustedDirectory ${1-tests} { minimal_init = './tests/minimal_init.lua' }"
29+
-c "RunTests ${1-tests}"
2430
echo "Success"

tests/fake_formatter.sh

+17-9
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,24 @@
22

33
set -e
44

5-
if [ -e "tests/fake_formatter_output" ]; then
6-
cat tests/fake_formatter_output
7-
else
8-
cat
5+
CODE=0
6+
if [ "$1" = "--fail" ]; then
7+
shift
8+
echo "failure" >&2
9+
CODE=1
10+
fi
11+
if [ "$1" = "--timeout" ]; then
12+
shift
13+
echo "timeout" >&2
14+
sleep 4
915
fi
1016

11-
if [ "$1" = "--fail" ]; then
12-
echo "failure" >&2
13-
exit 1
14-
elif [ "$1" = "--timeout" ]; then
15-
sleep 4
17+
output_file="$1"
18+
19+
if [ -n "$output_file" ] && [ -e "$output_file" ]; then
20+
cat "$output_file"
21+
else
22+
cat
1623
fi
1724

25+
exit $CODE

tests/injected/block_quote.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
text
2+
3+
> ```lua
4+
> local foo = 'bar'
5+
> ```
+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
text
2+
3+
> ```lua
4+
> |local foo = 'bar'|
5+
> ```

tests/injected/combined_injections.md

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
text
2+
3+
<!-- comment -->
4+
5+
```lua
6+
local foo = 'bar'
7+
```
8+
9+
10+
<!-- comment -->
11+
12+
```lua
13+
local foo = 'bar'
14+
```
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
text
2+
3+
<!-- comment -->
4+
5+
```lua
6+
|local foo = 'bar'|
7+
```
8+
9+
10+
<!-- comment -->
11+
12+
```lua
13+
|local foo = 'bar'|
14+
```

tests/injected/inline.ts

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
foo.innerHTML = `<div> hello </div>`;
2+
3+
bar.innerHTML = `
4+
<div> world </div>
5+
`;

tests/injected/inline.ts.formatted

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
foo.innerHTML = `|<div> hello </div>|`;
2+
3+
bar.innerHTML = `
4+
|<div> world </div>|
5+
`;

tests/injected/simple.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
text
2+
3+
```lua
4+
local foo = 'bar'
5+
```

tests/injected/simple.md.formatted

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
text
2+
3+
```lua
4+
|local foo = 'bar'|
5+
```

0 commit comments

Comments
 (0)