Skip to content

Allow specifying the location for the branch number in the branch name template. #73

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,10 @@ These arguments can be used with any subcommand:
- `-T, --target`: Remote target branch (default: "main")
- `--hyperlinks/--no-hyperlinks`: Enable/disable hyperlink support (default: enabled)
- `-V, --verbose`: Enable verbose output from Git subcommands (default: false)
- `--branch-name-template`: Template for generated branch names (default: "$USERNAME/stack")
- `--branch-name-template`: Template for generated branch names (default: "$USERNAME/stack"). The following variables are supported:
- `$USERNAME`: The username of the current user
- `$BRANCH`: The current branch name
- `$ID`: The location for the ID of the branch. The ID is determined by the order of creation of the branches. If `$ID` is not found in the template, the template will be appended with `/$ID`.

### Subcommands

Expand Down
72 changes: 48 additions & 24 deletions src/stack_pr/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def pprint(self, links: bool):
s = b(self.commit.commit_id()[:8])
pr_string = None
if self.has_pr():
pr_string = blue("#" + self.pr.split("/")[-1])
pr_string = blue("#" + last(self.pr))
else:
pr_string = red("no PR")
branch_string = None
Expand Down Expand Up @@ -384,16 +384,6 @@ def split_header(s: str) -> List[CommitHeader]:
return [CommitHeader(h) for h in s.split("\0")[:-1]]


def is_valid_ref(ref: str, branch_name_template: str) -> bool:
ref = ref.strip("'")

branch_name_base = get_branch_name_base(branch_name_template)
splits = ref.rsplit("/", 1)
if len(splits) < 2:
return False
return splits[-2].endswith(branch_name_base) and splits[-1].isnumeric()


def last(ref: str, sep: str = "/") -> str:
return ref.rsplit(sep, 1)[-1]

Expand Down Expand Up @@ -559,6 +549,13 @@ def add_or_update_metadata(e: StackEntry, needs_rebase: bool, verbose: bool) ->
return True


def fix_branch_name_template(branch_name_template: str):
if "$ID" not in branch_name_template:
return f"{branch_name_template}/$ID"

return branch_name_template


@cache
def get_branch_name_base(branch_name_template: str):
username = get_gh_username()
Expand All @@ -568,31 +565,54 @@ def get_branch_name_base(branch_name_template: str):
return branch_name_base


def get_branch_id(branch_name_template: str, branch_name: str):
branch_name_base = get_branch_name_base(branch_name_template)
pattern = branch_name_base.replace(r"$ID", r"(\d+)")
match = re.search(pattern, branch_name)
if match:
return match.group(1)
return None


def generate_branch_name(branch_name_template: str, branch_id: int):
branch_name_base = get_branch_name_base(branch_name_template)
branch_name = branch_name_base.replace(r"$ID", branch_id)
return branch_name


def get_taken_branch_ids(refs: List[str], branch_name_template: str) -> List[int]:
branch_ids = list(get_branch_id(branch_name_template, ref) for ref in refs)
branch_ids = [int(branch_id) for branch_id in branch_ids if branch_id is not None]
return branch_ids


def generate_available_branch_name(refs: List[str], branch_name_template: str) -> str:
branch_ids = get_taken_branch_ids(refs, branch_name_template)
max_ref_num = max(branch_ids) if branch_ids else 0
new_branch_id = max_ref_num + 1
return generate_branch_name(branch_name_template, str(new_branch_id))


def get_available_branch_name(remote: str, branch_name_template: str) -> str:
branch_name_base = get_branch_name_base(branch_name_template)

git_command_branch_template = branch_name_base.replace(r"$ID", "*")
refs = get_command_output(
[
"git",
"for-each-ref",
f"refs/remotes/{remote}/{branch_name_base}",
f"refs/remotes/{remote}/{git_command_branch_template}",
"--format='%(refname)'",
]
).split()

def check_ref(ref):
return is_valid_ref(ref, branch_name_base)

refs = list(filter(check_ref, refs))
max_ref_num = max(int(last(ref.strip("'"))) for ref in refs) if refs else 0
new_branch_id = max_ref_num + 1
refs = list([ref.strip("'") for ref in refs])
return generate_available_branch_name(refs, branch_name_template)

return f"{branch_name_base}/{new_branch_id}"


def get_next_available_branch_name(name: str) -> str:
base, id = name.rsplit("/", 1)
return f"{base}/{int(id) + 1}"
def get_next_available_branch_name(branch_name_template: str, name: str) -> str:
id = get_branch_id(branch_name_template, name)
return generate_branch_name(branch_name_template, str(int(id) + 1))


def set_head_branches(
Expand All @@ -604,7 +624,9 @@ def set_head_branches(
available_name = get_available_branch_name(remote, branch_name_template)
for e in filter(lambda e: not e.has_head(), st):
e.head = available_name
available_name = get_next_available_branch_name(available_name)
available_name = get_next_available_branch_name(
branch_name_template, available_name
)


def init_local_branches(
Expand Down Expand Up @@ -1359,6 +1381,8 @@ def main():
parser.print_help()
return

# Make sure "$ID" is present in the branch name template and append it if not
args.branch_name_template = fix_branch_name_template(args.branch_name_template)
common_args = CommonArgs.from_args(args)

if common_args.verbose:
Expand Down
83 changes: 83 additions & 0 deletions tests/test_misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import sys
from pathlib import Path

sys.path.append(str(Path(__file__).parent.parent / "src"))

from stack_pr.cli import (
get_branch_id,
generate_branch_name,
get_taken_branch_ids,
get_gh_username,
generate_available_branch_name,
)

import pytest


@pytest.fixture(scope="module")
def username():
return get_gh_username()


@pytest.mark.parametrize(
"template,branch_name,expected",
[
("feature-$ID-desc", "feature-123-desc", "123"),
("$USERNAME/stack/$ID", "{username}/stack/99", "99"),
("$USERNAME/stack/$ID", "refs/remote/origin/{username}/stack/99", "99"),
],
)
def test_get_branch_id(username, template, branch_name, expected):
branch_name = branch_name.format(username=username)
assert get_branch_id(template, branch_name) == expected


@pytest.mark.parametrize(
"template,branch_name",
[
("feature/$ID/desc", "feature/abc/desc"),
("feature/$ID/desc", "wrong/format"),
("$USERNAME/stack/$ID", "{username}/main/99"),
],
)
def test_get_branch_id_no_match(username, template, branch_name):
branch_name = branch_name.format(username=username)
assert get_branch_id(template, branch_name) is None


def test_generate_branch_name():
template = "feature/$ID/description"
assert generate_branch_name(template, "123") == "feature/123/description"


def test_get_taken_branch_ids():
template = "User/stack/$ID"
refs = [
"refs/remotes/origin/User/stack/104",
"refs/remotes/origin/User/stack/105",
"refs/remotes/origin/User/stack/134",
]
assert get_taken_branch_ids(refs, template) == [104, 105, 134]
refs = ["User/stack/104", "User/stack/105", "User/stack/134"]
assert get_taken_branch_ids(refs, template) == [104, 105, 134]
refs = ["User/stack/104", "AAAA/stack/105", "User/stack/134", "User/stack/bbb"]
assert get_taken_branch_ids(refs, template) == [104, 134]


def test_generate_available_branch_name():
template = "User/stack/$ID"
refs = [
"refs/remotes/origin/User/stack/104",
"refs/remotes/origin/User/stack/105",
"refs/remotes/origin/User/stack/134",
]
assert generate_available_branch_name(refs, template) == "User/stack/135"
refs = []
assert generate_available_branch_name(refs, template) == "User/stack/1"
template = "User-stack-$ID"
refs = [
"refs/remotes/origin/User-stack-104",
"refs/remotes/origin/User-stack-105",
"refs/remotes/origin/User-stack-134",
]
assert generate_available_branch_name(refs, template) == "User-stack-135"