Skip to content

Commit dda58b1

Browse files
authored
Allow specifying the location for the branch number in the branch name template. (#73)
Previously the branch number was always at the end after `/`. This PR generalizes it and allows it to be present anywhere in the template under the placeholder "$ID". Fixes #57.
1 parent b7927f3 commit dda58b1

File tree

3 files changed

+135
-25
lines changed

3 files changed

+135
-25
lines changed

README.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,10 @@ These arguments can be used with any subcommand:
250250
- `-T, --target`: Remote target branch (default: "main")
251251
- `--hyperlinks/--no-hyperlinks`: Enable/disable hyperlink support (default: enabled)
252252
- `-V, --verbose`: Enable verbose output from Git subcommands (default: false)
253-
- `--branch-name-template`: Template for generated branch names (default: "$USERNAME/stack")
253+
- `--branch-name-template`: Template for generated branch names (default: "$USERNAME/stack"). The following variables are supported:
254+
- `$USERNAME`: The username of the current user
255+
- `$BRANCH`: The current branch name
256+
- `$ID`: The location for the ID of the branch. The ID is determined by the order of creation of the branches. If `$ID` is not found in the template, the template will be appended with `/$ID`.
254257
255258
### Subcommands
256259

src/stack_pr/cli.py

Lines changed: 48 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ def pprint(self, links: bool):
286286
s = b(self.commit.commit_id()[:8])
287287
pr_string = None
288288
if self.has_pr():
289-
pr_string = blue("#" + self.pr.split("/")[-1])
289+
pr_string = blue("#" + last(self.pr))
290290
else:
291291
pr_string = red("no PR")
292292
branch_string = None
@@ -384,16 +384,6 @@ def split_header(s: str) -> List[CommitHeader]:
384384
return [CommitHeader(h) for h in s.split("\0")[:-1]]
385385

386386

387-
def is_valid_ref(ref: str, branch_name_template: str) -> bool:
388-
ref = ref.strip("'")
389-
390-
branch_name_base = get_branch_name_base(branch_name_template)
391-
splits = ref.rsplit("/", 1)
392-
if len(splits) < 2:
393-
return False
394-
return splits[-2].endswith(branch_name_base) and splits[-1].isnumeric()
395-
396-
397387
def last(ref: str, sep: str = "/") -> str:
398388
return ref.rsplit(sep, 1)[-1]
399389

@@ -559,6 +549,13 @@ def add_or_update_metadata(e: StackEntry, needs_rebase: bool, verbose: bool) ->
559549
return True
560550

561551

552+
def fix_branch_name_template(branch_name_template: str):
553+
if "$ID" not in branch_name_template:
554+
return f"{branch_name_template}/$ID"
555+
556+
return branch_name_template
557+
558+
562559
@cache
563560
def get_branch_name_base(branch_name_template: str):
564561
username = get_gh_username()
@@ -568,31 +565,54 @@ def get_branch_name_base(branch_name_template: str):
568565
return branch_name_base
569566

570567

568+
def get_branch_id(branch_name_template: str, branch_name: str):
569+
branch_name_base = get_branch_name_base(branch_name_template)
570+
pattern = branch_name_base.replace(r"$ID", r"(\d+)")
571+
match = re.search(pattern, branch_name)
572+
if match:
573+
return match.group(1)
574+
return None
575+
576+
577+
def generate_branch_name(branch_name_template: str, branch_id: int):
578+
branch_name_base = get_branch_name_base(branch_name_template)
579+
branch_name = branch_name_base.replace(r"$ID", branch_id)
580+
return branch_name
581+
582+
583+
def get_taken_branch_ids(refs: List[str], branch_name_template: str) -> List[int]:
584+
branch_ids = list(get_branch_id(branch_name_template, ref) for ref in refs)
585+
branch_ids = [int(branch_id) for branch_id in branch_ids if branch_id is not None]
586+
return branch_ids
587+
588+
589+
def generate_available_branch_name(refs: List[str], branch_name_template: str) -> str:
590+
branch_ids = get_taken_branch_ids(refs, branch_name_template)
591+
max_ref_num = max(branch_ids) if branch_ids else 0
592+
new_branch_id = max_ref_num + 1
593+
return generate_branch_name(branch_name_template, str(new_branch_id))
594+
595+
571596
def get_available_branch_name(remote: str, branch_name_template: str) -> str:
572597
branch_name_base = get_branch_name_base(branch_name_template)
573598

599+
git_command_branch_template = branch_name_base.replace(r"$ID", "*")
574600
refs = get_command_output(
575601
[
576602
"git",
577603
"for-each-ref",
578-
f"refs/remotes/{remote}/{branch_name_base}",
604+
f"refs/remotes/{remote}/{git_command_branch_template}",
579605
"--format='%(refname)'",
580606
]
581607
).split()
582608

583-
def check_ref(ref):
584-
return is_valid_ref(ref, branch_name_base)
585-
586-
refs = list(filter(check_ref, refs))
587-
max_ref_num = max(int(last(ref.strip("'"))) for ref in refs) if refs else 0
588-
new_branch_id = max_ref_num + 1
609+
refs = list([ref.strip("'") for ref in refs])
610+
return generate_available_branch_name(refs, branch_name_template)
589611

590-
return f"{branch_name_base}/{new_branch_id}"
591612

592-
593-
def get_next_available_branch_name(name: str) -> str:
594-
base, id = name.rsplit("/", 1)
595-
return f"{base}/{int(id) + 1}"
613+
def get_next_available_branch_name(branch_name_template: str, name: str) -> str:
614+
id = get_branch_id(branch_name_template, name)
615+
return generate_branch_name(branch_name_template, str(int(id) + 1))
596616

597617

598618
def set_head_branches(
@@ -604,7 +624,9 @@ def set_head_branches(
604624
available_name = get_available_branch_name(remote, branch_name_template)
605625
for e in filter(lambda e: not e.has_head(), st):
606626
e.head = available_name
607-
available_name = get_next_available_branch_name(available_name)
627+
available_name = get_next_available_branch_name(
628+
branch_name_template, available_name
629+
)
608630

609631

610632
def init_local_branches(
@@ -1359,6 +1381,8 @@ def main():
13591381
parser.print_help()
13601382
return
13611383

1384+
# Make sure "$ID" is present in the branch name template and append it if not
1385+
args.branch_name_template = fix_branch_name_template(args.branch_name_template)
13621386
common_args = CommonArgs.from_args(args)
13631387

13641388
if common_args.verbose:

tests/test_misc.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import sys
2+
from pathlib import Path
3+
4+
sys.path.append(str(Path(__file__).parent.parent / "src"))
5+
6+
from stack_pr.cli import (
7+
get_branch_id,
8+
generate_branch_name,
9+
get_taken_branch_ids,
10+
get_gh_username,
11+
generate_available_branch_name,
12+
)
13+
14+
import pytest
15+
16+
17+
@pytest.fixture(scope="module")
18+
def username():
19+
return get_gh_username()
20+
21+
22+
@pytest.mark.parametrize(
23+
"template,branch_name,expected",
24+
[
25+
("feature-$ID-desc", "feature-123-desc", "123"),
26+
("$USERNAME/stack/$ID", "{username}/stack/99", "99"),
27+
("$USERNAME/stack/$ID", "refs/remote/origin/{username}/stack/99", "99"),
28+
],
29+
)
30+
def test_get_branch_id(username, template, branch_name, expected):
31+
branch_name = branch_name.format(username=username)
32+
assert get_branch_id(template, branch_name) == expected
33+
34+
35+
@pytest.mark.parametrize(
36+
"template,branch_name",
37+
[
38+
("feature/$ID/desc", "feature/abc/desc"),
39+
("feature/$ID/desc", "wrong/format"),
40+
("$USERNAME/stack/$ID", "{username}/main/99"),
41+
],
42+
)
43+
def test_get_branch_id_no_match(username, template, branch_name):
44+
branch_name = branch_name.format(username=username)
45+
assert get_branch_id(template, branch_name) is None
46+
47+
48+
def test_generate_branch_name():
49+
template = "feature/$ID/description"
50+
assert generate_branch_name(template, "123") == "feature/123/description"
51+
52+
53+
def test_get_taken_branch_ids():
54+
template = "User/stack/$ID"
55+
refs = [
56+
"refs/remotes/origin/User/stack/104",
57+
"refs/remotes/origin/User/stack/105",
58+
"refs/remotes/origin/User/stack/134",
59+
]
60+
assert get_taken_branch_ids(refs, template) == [104, 105, 134]
61+
refs = ["User/stack/104", "User/stack/105", "User/stack/134"]
62+
assert get_taken_branch_ids(refs, template) == [104, 105, 134]
63+
refs = ["User/stack/104", "AAAA/stack/105", "User/stack/134", "User/stack/bbb"]
64+
assert get_taken_branch_ids(refs, template) == [104, 134]
65+
66+
67+
def test_generate_available_branch_name():
68+
template = "User/stack/$ID"
69+
refs = [
70+
"refs/remotes/origin/User/stack/104",
71+
"refs/remotes/origin/User/stack/105",
72+
"refs/remotes/origin/User/stack/134",
73+
]
74+
assert generate_available_branch_name(refs, template) == "User/stack/135"
75+
refs = []
76+
assert generate_available_branch_name(refs, template) == "User/stack/1"
77+
template = "User-stack-$ID"
78+
refs = [
79+
"refs/remotes/origin/User-stack-104",
80+
"refs/remotes/origin/User-stack-105",
81+
"refs/remotes/origin/User-stack-134",
82+
]
83+
assert generate_available_branch_name(refs, template) == "User-stack-135"

0 commit comments

Comments
 (0)