@@ -20,12 +20,13 @@ def get_slug_from_remote_url(remote_url: str) -> str:
20
20
return potential_slug .removesuffix (".git" )
21
21
22
22
23
- @contextlib .contextmanager
24
- def transitioning_branches (
25
- repo : Repo , branch_prefix : str , branch_suffix : str = "" , force : bool = True
26
- ) -> Generator [tuple [Head , Head ], None , None ]:
23
+ def get_current_branch (repo : Repo ) -> Head :
24
+ remote = repo .remote ("origin" )
27
25
if repo .head .is_detached :
28
- from_branch = next ((branch for branch in repo .branches if branch .commit == repo .head .commit ), None )
26
+ from_branch = next (
27
+ (branch for branch in remote .refs if branch .commit == repo .head .commit and branch .remote_head != "HEAD" ),
28
+ None ,
29
+ )
29
30
else :
30
31
from_branch = repo .active_branch
31
32
@@ -35,16 +36,27 @@ def transitioning_branches(
35
36
"Make sure repository is not in a detached HEAD state with additional commits."
36
37
)
37
38
38
- next_branch_name = f"{ branch_prefix } { from_branch .name } { branch_suffix } "
39
+ return from_branch
40
+
41
+
42
+ @contextlib .contextmanager
43
+ def transitioning_branches (
44
+ repo : Repo , branch_prefix : str , branch_suffix : str = "" , force : bool = True
45
+ ) -> Generator [tuple [str , str ], None , None ]:
46
+ from_branch = get_current_branch (repo )
47
+ from_branch_name = from_branch .name if not from_branch .is_remote () else from_branch .remote_head
48
+ next_branch_name = f"{ branch_prefix } { from_branch_name } { branch_suffix } "
39
49
if next_branch_name in repo .heads and not force :
40
- raise ValueError (f'Branch "{ next_branch_name } " already exists.' )
50
+ raise ValueError (f'Local Branch "{ next_branch_name } " already exists.' )
51
+ if next_branch_name in repo .remote ("origin" ).refs and not force :
52
+ raise ValueError (f'Remote Branch "{ next_branch_name } " already exists.' )
41
53
42
54
logger .info (f'Creating new branch "{ next_branch_name } ".' )
43
55
to_branch = repo .create_head (next_branch_name , force = force )
44
56
45
57
try :
46
58
to_branch .checkout ()
47
- yield from_branch , to_branch
59
+ yield from_branch_name , next_branch_name
48
60
finally :
49
61
from_branch .checkout ()
50
62
@@ -137,7 +149,9 @@ def run(self) -> dict:
137
149
repo = git .Repo (Path .cwd ())
138
150
if not self .enabled :
139
151
logger .debug ("Branch creation is disabled." )
140
- return dict (target_branch = repo .active_branch .name )
152
+ from_branch = get_current_branch (repo )
153
+ from_branch_name = from_branch .name if not from_branch .is_remote () else from_branch .remote_head
154
+ return dict (target_branch = from_branch_name )
141
155
142
156
modified_files = {modified_code_file ["path" ] for modified_code_file in self .modified_code_files }
143
157
@@ -153,6 +167,6 @@ def run(self) -> dict:
153
167
154
168
logger .info (f"Run completed { self .__class__ .__name__ } " )
155
169
return dict (
156
- base_branch = from_branch . name ,
157
- target_branch = to_branch . name ,
170
+ base_branch = from_branch ,
171
+ target_branch = to_branch ,
158
172
)
0 commit comments