Skip to content

Commit 7b9728e

Browse files
committed
Allow specifying the location for the branch number in the branch name template.
Previously the branch number was always at the end after `/`. This PR generalizes it and allows it to be present anywhere in the template under the placeholder "$ID". Fixes #57.
1 parent d353d74 commit 7b9728e

File tree

3 files changed

+138
-26
lines changed

3 files changed

+138
-26
lines changed

README.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,10 @@ These arguments can be used with any subcommand:
250250
- `-T, --target`: Remote target branch (default: "main")
251251
- `--hyperlinks/--no-hyperlinks`: Enable/disable hyperlink support (default: enabled)
252252
- `-V, --verbose`: Enable verbose output from Git subcommands (default: false)
253-
- `--branch-name-template`: Template for generated branch names (default: "$USERNAME/stack")
253+
- `--branch-name-template`: Template for generated branch names (default: "$USERNAME/stack"). The following variables are supported:
254+
- `$USERNAME`: The username of the current user
255+
- `$BRANCH`: The current branch name
256+
- `$ID`: The location for the ID of the branch. The ID is determined by the order of creation of the branches. If `$ID` is not found in the template, the template will be appended with `/$ID`.
254257
255258
### Subcommands
256259

src/stack_pr/cli.py

Lines changed: 51 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ def pprint(self, links: bool):
286286
s = b(self.commit.commit_id()[:8])
287287
pr_string = None
288288
if self.has_pr():
289-
pr_string = blue("#" + self.pr.split("/")[-1])
289+
pr_string = blue("#" + last(self.pr))
290290
else:
291291
pr_string = red("no PR")
292292
branch_string = None
@@ -384,16 +384,6 @@ def split_header(s: str) -> List[CommitHeader]:
384384
return [CommitHeader(h) for h in s.split("\0")[:-1]]
385385

386386

387-
def is_valid_ref(ref: str, branch_name_template: str) -> bool:
388-
ref = ref.strip("'")
389-
390-
branch_name_base = get_branch_name_base(branch_name_template)
391-
splits = ref.rsplit("/", 1)
392-
if len(splits) < 2:
393-
return False
394-
return splits[-2].endswith(branch_name_base) and splits[-1].isnumeric()
395-
396-
397387
def last(ref: str, sep: str = "/") -> str:
398388
return ref.rsplit(sep, 1)[-1]
399389

@@ -559,6 +549,13 @@ def add_or_update_metadata(e: StackEntry, needs_rebase: bool, verbose: bool) ->
559549
return True
560550

561551

552+
def fix_branch_name_template(branch_name_template: str):
553+
if "$ID" not in branch_name_template:
554+
return f"{branch_name_template}/$ID"
555+
556+
return branch_name_template
557+
558+
562559
@cache
563560
def get_branch_name_base(branch_name_template: str):
564561
username = get_gh_username()
@@ -568,31 +565,54 @@ def get_branch_name_base(branch_name_template: str):
568565
return branch_name_base
569566

570567

568+
def get_branch_id(branch_name_template: str, branch_name: str):
569+
branch_name_base = get_branch_name_base(branch_name_template)
570+
pattern = branch_name_base.replace(r"$ID", r"(\d+)")
571+
match = re.search(pattern, branch_name)
572+
if match:
573+
return match.group(1)
574+
return None
575+
576+
577+
def generate_branch_name(branch_name_template: str, branch_id: int):
578+
branch_name_base = get_branch_name_base(branch_name_template)
579+
branch_name = branch_name_base.replace(r"$ID", branch_id)
580+
return branch_name
581+
582+
583+
def get_taken_branch_ids(refs: List[str], branch_name_template: str) -> List[int]:
584+
branch_ids = list(get_branch_id(branch_name_template, ref) for ref in refs)
585+
branch_ids = [int(branch_id) for branch_id in branch_ids if branch_id is not None]
586+
return branch_ids
587+
588+
589+
def generate_available_branch_name(refs: List[str], branch_name_template: str) -> str:
590+
branch_ids = get_taken_branch_ids(refs, branch_name_template)
591+
max_ref_num = max(branch_ids) if branch_ids else 0
592+
new_branch_id = max_ref_num + 1
593+
return generate_branch_name(branch_name_template, str(new_branch_id))
594+
595+
571596
def get_available_branch_name(remote: str, branch_name_template: str) -> str:
572597
branch_name_base = get_branch_name_base(branch_name_template)
573598

599+
git_command_branch_template = branch_name_base.replace(r"$ID", "*")
574600
refs = get_command_output(
575601
[
576602
"git",
577603
"for-each-ref",
578-
f"refs/remotes/{remote}/{branch_name_base}",
604+
f"refs/remotes/{remote}/{git_command_branch_template}",
579605
"--format='%(refname)'",
580606
]
581607
).split()
582608

583-
def check_ref(ref):
584-
return is_valid_ref(ref, branch_name_base)
585-
586-
refs = list(filter(check_ref, refs))
587-
max_ref_num = max(int(last(ref.strip("'"))) for ref in refs) if refs else 0
588-
new_branch_id = max_ref_num + 1
609+
refs = list([ref.strip("'") for ref in refs])
610+
return generate_available_branch_name(refs, branch_name_template)
589611

590-
return f"{branch_name_base}/{new_branch_id}"
591612

592-
593-
def get_next_available_branch_name(name: str) -> str:
594-
base, id = name.rsplit("/", 1)
595-
return f"{base}/{int(id) + 1}"
613+
def get_next_available_branch_name(branch_name_template: str, name: str) -> str:
614+
id = get_branch_id(branch_name_template, name)
615+
return generate_branch_name(branch_name_template, str(int(id) + 1))
596616

597617

598618
def set_head_branches(
@@ -604,7 +624,9 @@ def set_head_branches(
604624
available_name = get_available_branch_name(remote, branch_name_template)
605625
for e in filter(lambda e: not e.has_head(), st):
606626
e.head = available_name
607-
available_name = get_next_available_branch_name(available_name)
627+
available_name = get_next_available_branch_name(
628+
branch_name_template, available_name
629+
)
608630

609631

610632
def init_local_branches(
@@ -1359,6 +1381,8 @@ def main():
13591381
parser.print_help()
13601382
return
13611383

1384+
# Make sure "$ID" is present in the branch name template and append it if not
1385+
args.branch_name_template = fix_branch_name_template(args.branch_name_template)
13621386
common_args = CommonArgs.from_args(args)
13631387

13641388
if common_args.verbose:
@@ -1377,7 +1401,9 @@ def main():
13771401

13781402
if args.command in ["submit", "export"]:
13791403
if args.stash:
1380-
run_shell_command(["git", "stash", "save"], quiet=not common_args.verbose)
1404+
run_shell_command(
1405+
["git", "stash", "save"], quiet=not common_args.verbose
1406+
)
13811407
command_submit(
13821408
common_args,
13831409
args.draft,

tests/test_misc.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import sys
2+
from pathlib import Path
3+
4+
sys.path.append(str(Path(__file__).parent.parent / "src"))
5+
6+
from stack_pr.cli import (
7+
get_branch_id,
8+
generate_branch_name,
9+
get_taken_branch_ids,
10+
get_gh_username,
11+
generate_available_branch_name,
12+
)
13+
14+
import pytest
15+
16+
17+
@pytest.fixture(scope="module")
18+
def username():
19+
return get_gh_username()
20+
21+
22+
@pytest.mark.parametrize(
23+
"template,branch_name,expected",
24+
[
25+
("feature-$ID-desc", "feature-123-desc", "123"),
26+
("$USERNAME/stack/$ID", "{username}/stack/99", "99"),
27+
("$USERNAME/stack/$ID", "refs/remote/origin/{username}/stack/99", "99"),
28+
],
29+
)
30+
def test_get_branch_id(username, template, branch_name, expected):
31+
branch_name = branch_name.format(username=username)
32+
assert get_branch_id(template, branch_name) == expected
33+
34+
35+
@pytest.mark.parametrize(
36+
"template,branch_name",
37+
[
38+
("feature/$ID/desc", "feature/abc/desc"),
39+
("feature/$ID/desc", "wrong/format"),
40+
("$USERNAME/stack/$ID", "{username}/main/99"),
41+
],
42+
)
43+
def test_get_branch_id_no_match(username, template, branch_name):
44+
branch_name = branch_name.format(username=username)
45+
assert get_branch_id(template, branch_name) is None
46+
47+
48+
def test_generate_branch_name():
49+
template = "feature/$ID/description"
50+
assert generate_branch_name(template, "123") == "feature/123/description"
51+
52+
53+
def test_get_taken_branch_ids():
54+
template = "User/stack/$ID"
55+
refs = [
56+
"refs/remotes/origin/User/stack/104",
57+
"refs/remotes/origin/User/stack/105",
58+
"refs/remotes/origin/User/stack/134",
59+
]
60+
assert get_taken_branch_ids(refs, template) == [104, 105, 134]
61+
refs = ["User/stack/104", "User/stack/105", "User/stack/134"]
62+
assert get_taken_branch_ids(refs, template) == [104, 105, 134]
63+
refs = ["User/stack/104", "AAAA/stack/105", "User/stack/134", "User/stack/bbb"]
64+
assert get_taken_branch_ids(refs, template) == [104, 134]
65+
66+
67+
def test_generate_available_branch_name():
68+
template = "User/stack/$ID"
69+
refs = [
70+
"refs/remotes/origin/User/stack/104",
71+
"refs/remotes/origin/User/stack/105",
72+
"refs/remotes/origin/User/stack/134",
73+
]
74+
assert generate_available_branch_name(refs, template) == "User/stack/135"
75+
refs = []
76+
assert generate_available_branch_name(refs, template) == "User/stack/1"
77+
template = "User-stack-$ID"
78+
refs = [
79+
"refs/remotes/origin/User-stack-104",
80+
"refs/remotes/origin/User-stack-105",
81+
"refs/remotes/origin/User-stack-134",
82+
]
83+
assert generate_available_branch_name(refs, template) == "User-stack-135"

0 commit comments

Comments
 (0)