@@ -286,7 +286,7 @@ def pprint(self, links: bool):
286
286
s = b (self .commit .commit_id ()[:8 ])
287
287
pr_string = None
288
288
if self .has_pr ():
289
- pr_string = blue ("#" + self .pr . split ( "/" )[ - 1 ] )
289
+ pr_string = blue ("#" + last ( self .pr ) )
290
290
else :
291
291
pr_string = red ("no PR" )
292
292
branch_string = None
@@ -384,16 +384,6 @@ def split_header(s: str) -> List[CommitHeader]:
384
384
return [CommitHeader (h ) for h in s .split ("\0 " )[:- 1 ]]
385
385
386
386
387
- def is_valid_ref (ref : str , branch_name_template : str ) -> bool :
388
- ref = ref .strip ("'" )
389
-
390
- branch_name_base = get_branch_name_base (branch_name_template )
391
- splits = ref .rsplit ("/" , 1 )
392
- if len (splits ) < 2 :
393
- return False
394
- return splits [- 2 ].endswith (branch_name_base ) and splits [- 1 ].isnumeric ()
395
-
396
-
397
387
def last (ref : str , sep : str = "/" ) -> str :
398
388
return ref .rsplit (sep , 1 )[- 1 ]
399
389
@@ -559,6 +549,13 @@ def add_or_update_metadata(e: StackEntry, needs_rebase: bool, verbose: bool) ->
559
549
return True
560
550
561
551
552
+ def fix_branch_name_template (branch_name_template : str ):
553
+ if "$ID" not in branch_name_template :
554
+ return f"{ branch_name_template } /$ID"
555
+
556
+ return branch_name_template
557
+
558
+
562
559
@cache
563
560
def get_branch_name_base (branch_name_template : str ):
564
561
username = get_gh_username ()
@@ -568,31 +565,54 @@ def get_branch_name_base(branch_name_template: str):
568
565
return branch_name_base
569
566
570
567
568
+ def get_branch_id (branch_name_template : str , branch_name : str ):
569
+ branch_name_base = get_branch_name_base (branch_name_template )
570
+ pattern = branch_name_base .replace (r"$ID" , r"(\d+)" )
571
+ match = re .search (pattern , branch_name )
572
+ if match :
573
+ return match .group (1 )
574
+ return None
575
+
576
+
577
+ def generate_branch_name (branch_name_template : str , branch_id : int ):
578
+ branch_name_base = get_branch_name_base (branch_name_template )
579
+ branch_name = branch_name_base .replace (r"$ID" , branch_id )
580
+ return branch_name
581
+
582
+
583
+ def get_taken_branch_ids (refs : List [str ], branch_name_template : str ) -> List [int ]:
584
+ branch_ids = list (get_branch_id (branch_name_template , ref ) for ref in refs )
585
+ branch_ids = [int (branch_id ) for branch_id in branch_ids if branch_id is not None ]
586
+ return branch_ids
587
+
588
+
589
+ def generate_available_branch_name (refs : List [str ], branch_name_template : str ) -> str :
590
+ branch_ids = get_taken_branch_ids (refs , branch_name_template )
591
+ max_ref_num = max (branch_ids ) if branch_ids else 0
592
+ new_branch_id = max_ref_num + 1
593
+ return generate_branch_name (branch_name_template , str (new_branch_id ))
594
+
595
+
571
596
def get_available_branch_name (remote : str , branch_name_template : str ) -> str :
572
597
branch_name_base = get_branch_name_base (branch_name_template )
573
598
599
+ git_command_branch_template = branch_name_base .replace (r"$ID" , "*" )
574
600
refs = get_command_output (
575
601
[
576
602
"git" ,
577
603
"for-each-ref" ,
578
- f"refs/remotes/{ remote } /{ branch_name_base } " ,
604
+ f"refs/remotes/{ remote } /{ git_command_branch_template } " ,
579
605
"--format='%(refname)'" ,
580
606
]
581
607
).split ()
582
608
583
- def check_ref (ref ):
584
- return is_valid_ref (ref , branch_name_base )
585
-
586
- refs = list (filter (check_ref , refs ))
587
- max_ref_num = max (int (last (ref .strip ("'" ))) for ref in refs ) if refs else 0
588
- new_branch_id = max_ref_num + 1
609
+ refs = list ([ref .strip ("'" ) for ref in refs ])
610
+ return generate_available_branch_name (refs , branch_name_template )
589
611
590
- return f"{ branch_name_base } /{ new_branch_id } "
591
612
592
-
593
- def get_next_available_branch_name (name : str ) -> str :
594
- base , id = name .rsplit ("/" , 1 )
595
- return f"{ base } /{ int (id ) + 1 } "
613
+ def get_next_available_branch_name (branch_name_template : str , name : str ) -> str :
614
+ id = get_branch_id (branch_name_template , name )
615
+ return generate_branch_name (branch_name_template , str (int (id ) + 1 ))
596
616
597
617
598
618
def set_head_branches (
@@ -604,7 +624,9 @@ def set_head_branches(
604
624
available_name = get_available_branch_name (remote , branch_name_template )
605
625
for e in filter (lambda e : not e .has_head (), st ):
606
626
e .head = available_name
607
- available_name = get_next_available_branch_name (available_name )
627
+ available_name = get_next_available_branch_name (
628
+ branch_name_template , available_name
629
+ )
608
630
609
631
610
632
def init_local_branches (
@@ -1359,6 +1381,8 @@ def main():
1359
1381
parser .print_help ()
1360
1382
return
1361
1383
1384
+ # Make sure "$ID" is present in the branch name template and append it if not
1385
+ args .branch_name_template = fix_branch_name_template (args .branch_name_template )
1362
1386
common_args = CommonArgs .from_args (args )
1363
1387
1364
1388
if common_args .verbose :
@@ -1377,7 +1401,9 @@ def main():
1377
1401
1378
1402
if args .command in ["submit" , "export" ]:
1379
1403
if args .stash :
1380
- run_shell_command (["git" , "stash" , "save" ], quiet = not common_args .verbose )
1404
+ run_shell_command (
1405
+ ["git" , "stash" , "save" ], quiet = not common_args .verbose
1406
+ )
1381
1407
command_submit (
1382
1408
common_args ,
1383
1409
args .draft ,
0 commit comments