@@ -12,7 +12,15 @@ import (
12
12
"github.com/github/gh-combine/internal/github"
13
13
)
14
14
15
- func CombinePRs (ctx context.Context , graphQlClient * api.GraphQLClient , restClient * api.RESTClient , repo github.Repo , pulls github.Pulls ) error {
15
+ // Updated RESTClientInterface to match the method signatures of api.RESTClient
16
+ type RESTClientInterface interface {
17
+ Post (endpoint string , body io.Reader , response interface {}) error
18
+ Get (endpoint string , response interface {}) error
19
+ Delete (endpoint string , response interface {}) error
20
+ Patch (endpoint string , body io.Reader , response interface {}) error
21
+ }
22
+
23
+ func CombinePRs (ctx context.Context , graphQlClient * api.GraphQLClient , restClient RESTClientInterface , repo github.Repo , pulls github.Pulls ) error {
16
24
// Define the combined branch name
17
25
workingBranchName := combineBranchName + workingBranchSuffix
18
26
@@ -87,7 +95,7 @@ func CombinePRs(ctx context.Context, graphQlClient *api.GraphQLClient, restClien
87
95
// Create the combined PR
88
96
prBody := generatePRBody (combinedPRs , mergeFailedPRs )
89
97
prTitle := "Combined PRs"
90
- err = createPullRequest (ctx , restClient , repo , prTitle , combineBranchName , repoDefaultBranch , prBody )
98
+ err = createPullRequest (ctx , restClient , repo , prTitle , combineBranchName , repoDefaultBranch , prBody , addLabels , addAssignees )
91
99
if err != nil {
92
100
return fmt .Errorf ("failed to create combined PR: %w" , err )
93
101
}
@@ -102,7 +110,7 @@ func isMergeConflictError(err error) bool {
102
110
}
103
111
104
112
// Find the default branch of a repository
105
- func getDefaultBranch (ctx context.Context , client * api. RESTClient , repo github.Repo ) (string , error ) {
113
+ func getDefaultBranch (ctx context.Context , client RESTClientInterface , repo github.Repo ) (string , error ) {
106
114
var repoInfo struct {
107
115
DefaultBranch string `json:"default_branch"`
108
116
}
@@ -115,7 +123,7 @@ func getDefaultBranch(ctx context.Context, client *api.RESTClient, repo github.R
115
123
}
116
124
117
125
// Get the SHA of a given branch
118
- func getBranchSHA (ctx context.Context , client * api. RESTClient , repo github.Repo , branch string ) (string , error ) {
126
+ func getBranchSHA (ctx context.Context , client RESTClientInterface , repo github.Repo , branch string ) (string , error ) {
119
127
var ref struct {
120
128
Object struct {
121
129
SHA string `json:"sha"`
@@ -148,13 +156,13 @@ func generatePRBody(combinedPRs, mergeFailedPRs []string) string {
148
156
}
149
157
150
158
// deleteBranch deletes a branch in the repository
151
- func deleteBranch (ctx context.Context , client * api. RESTClient , repo github.Repo , branch string ) error {
159
+ func deleteBranch (ctx context.Context , client RESTClientInterface , repo github.Repo , branch string ) error {
152
160
endpoint := fmt .Sprintf ("repos/%s/%s/git/refs/heads/%s" , repo .Owner , repo .Repo , branch )
153
161
return client .Delete (endpoint , nil )
154
162
}
155
163
156
164
// createBranch creates a new branch in the repository
157
- func createBranch (ctx context.Context , client * api. RESTClient , repo github.Repo , branch , sha string ) error {
165
+ func createBranch (ctx context.Context , client RESTClientInterface , repo github.Repo , branch , sha string ) error {
158
166
endpoint := fmt .Sprintf ("repos/%s/%s/git/refs" , repo .Owner , repo .Repo )
159
167
payload := map [string ]string {
160
168
"ref" : "refs/heads/" + branch ,
@@ -168,7 +176,7 @@ func createBranch(ctx context.Context, client *api.RESTClient, repo github.Repo,
168
176
}
169
177
170
178
// mergeBranch merges a branch into the base branch
171
- func mergeBranch (ctx context.Context , client * api. RESTClient , repo github.Repo , base , head string ) error {
179
+ func mergeBranch (ctx context.Context , client RESTClientInterface , repo github.Repo , base , head string ) error {
172
180
endpoint := fmt .Sprintf ("repos/%s/%s/merges" , repo .Owner , repo .Repo )
173
181
payload := map [string ]string {
174
182
"base" : base ,
@@ -182,7 +190,7 @@ func mergeBranch(ctx context.Context, client *api.RESTClient, repo github.Repo,
182
190
}
183
191
184
192
// updateRef updates a branch to point to the latest commit of another branch
185
- func updateRef (ctx context.Context , client * api. RESTClient , repo github.Repo , branch , sourceBranch string ) error {
193
+ func updateRef (ctx context.Context , client RESTClientInterface , repo github.Repo , branch , sourceBranch string ) error {
186
194
// Get the SHA of the source branch
187
195
var ref struct {
188
196
Object struct {
@@ -208,15 +216,25 @@ func updateRef(ctx context.Context, client *api.RESTClient, repo github.Repo, br
208
216
return client .Patch (endpoint , body , nil )
209
217
}
210
218
211
- // createPullRequest creates a new pull request
212
- func createPullRequest (ctx context.Context , client * api.RESTClient , repo github.Repo , title , head , base , body string ) error {
219
+ func createPullRequest (ctx context.Context , client RESTClientInterface , repo github.Repo , title , head , base , body string , labels , assignees []string ) error {
213
220
endpoint := fmt .Sprintf ("repos/%s/%s/pulls" , repo .Owner , repo .Repo )
214
- payload := map [string ]string {
215
- "title" : title ,
216
- "head" : head ,
217
- "base" : base ,
218
- "body" : body ,
221
+ payload := map [string ]interface {} {
222
+ "title" : title ,
223
+ "head" : head ,
224
+ "base" : base ,
225
+ "body" : body ,
219
226
}
227
+
228
+ // Add labels if provided
229
+ if len (labels ) > 0 {
230
+ payload ["labels" ] = labels
231
+ }
232
+
233
+ // Add assignees if provided
234
+ if len (assignees ) > 0 {
235
+ payload ["assignees" ] = assignees
236
+ }
237
+
220
238
requestBody , err := encodePayload (payload )
221
239
if err != nil {
222
240
return fmt .Errorf ("failed to encode payload: %w" , err )
0 commit comments