Skip to content

Commit 790baac

Browse files
committed
add logic for adding labels and assignees to the resulting PR that gets created by this CLI
1 parent 2b73a35 commit 790baac

33 files changed

+8695
-18
lines changed

go.mod

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@ require (
1010
github.com/cli/go-gh/v2 v2.12.0
1111
github.com/cli/shurcooL-graphql v0.0.4
1212
github.com/spf13/cobra v1.9.1
13+
github.com/stretchr/testify v1.10.0
1314
)
1415

1516
require (
1617
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
1718
github.com/cli/safeexec v1.0.0 // indirect
19+
github.com/davecgh/go-spew v1.1.1 // indirect
1820
github.com/fatih/color v1.7.0 // indirect
1921
github.com/henvic/httpretty v0.0.6 // indirect
2022
github.com/inconshreveable/mousetrap v1.1.0 // indirect
@@ -23,6 +25,7 @@ require (
2325
github.com/mattn/go-colorable v0.1.13 // indirect
2426
github.com/mattn/go-isatty v0.0.20 // indirect
2527
github.com/muesli/termenv v0.16.0 // indirect
28+
github.com/pmezard/go-difflib v1.0.0 // indirect
2629
github.com/rivo/uniseg v0.4.7 // indirect
2730
github.com/spf13/pflag v1.0.6 // indirect
2831
github.com/thlib/go-timezone-local v0.0.0-20210907160436-ef149e42d28e // indirect

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo=
4242
github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0=
4343
github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o=
4444
github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
45-
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
46-
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
45+
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
46+
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
4747
github.com/thlib/go-timezone-local v0.0.0-20210907160436-ef149e42d28e h1:BuzhfgfWQbX0dWzYzT1zsORLnHRv3bcRcsaUk0VmXA8=
4848
github.com/thlib/go-timezone-local v0.0.0-20210907160436-ef149e42d28e/go.mod h1:/Tnicc6m/lsJE0irFMA0LfIwTBo4QP7A8IfyIv4zZKI=
4949
golang.org/x/sys v0.0.0-20210831042530-f4d43177bf5e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

internal/cmd/combine_prs.go

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,15 @@ import (
1212
"github.com/github/gh-combine/internal/github"
1313
)
1414

15-
func CombinePRs(ctx context.Context, graphQlClient *api.GraphQLClient, restClient *api.RESTClient, repo github.Repo, pulls github.Pulls) error {
15+
// Updated RESTClientInterface to match the method signatures of api.RESTClient
16+
type RESTClientInterface interface {
17+
Post(endpoint string, body io.Reader, response interface{}) error
18+
Get(endpoint string, response interface{}) error
19+
Delete(endpoint string, response interface{}) error
20+
Patch(endpoint string, body io.Reader, response interface{}) error
21+
}
22+
23+
func CombinePRs(ctx context.Context, graphQlClient *api.GraphQLClient, restClient RESTClientInterface, repo github.Repo, pulls github.Pulls) error {
1624
// Define the combined branch name
1725
workingBranchName := combineBranchName + workingBranchSuffix
1826

@@ -87,7 +95,7 @@ func CombinePRs(ctx context.Context, graphQlClient *api.GraphQLClient, restClien
8795
// Create the combined PR
8896
prBody := generatePRBody(combinedPRs, mergeFailedPRs)
8997
prTitle := "Combined PRs"
90-
err = createPullRequest(ctx, restClient, repo, prTitle, combineBranchName, repoDefaultBranch, prBody)
98+
err = createPullRequest(ctx, restClient, repo, prTitle, combineBranchName, repoDefaultBranch, prBody, addLabels, addAssignees)
9199
if err != nil {
92100
return fmt.Errorf("failed to create combined PR: %w", err)
93101
}
@@ -102,7 +110,7 @@ func isMergeConflictError(err error) bool {
102110
}
103111

104112
// Find the default branch of a repository
105-
func getDefaultBranch(ctx context.Context, client *api.RESTClient, repo github.Repo) (string, error) {
113+
func getDefaultBranch(ctx context.Context, client RESTClientInterface, repo github.Repo) (string, error) {
106114
var repoInfo struct {
107115
DefaultBranch string `json:"default_branch"`
108116
}
@@ -115,7 +123,7 @@ func getDefaultBranch(ctx context.Context, client *api.RESTClient, repo github.R
115123
}
116124

117125
// Get the SHA of a given branch
118-
func getBranchSHA(ctx context.Context, client *api.RESTClient, repo github.Repo, branch string) (string, error) {
126+
func getBranchSHA(ctx context.Context, client RESTClientInterface, repo github.Repo, branch string) (string, error) {
119127
var ref struct {
120128
Object struct {
121129
SHA string `json:"sha"`
@@ -148,13 +156,13 @@ func generatePRBody(combinedPRs, mergeFailedPRs []string) string {
148156
}
149157

150158
// deleteBranch deletes a branch in the repository
151-
func deleteBranch(ctx context.Context, client *api.RESTClient, repo github.Repo, branch string) error {
159+
func deleteBranch(ctx context.Context, client RESTClientInterface, repo github.Repo, branch string) error {
152160
endpoint := fmt.Sprintf("repos/%s/%s/git/refs/heads/%s", repo.Owner, repo.Repo, branch)
153161
return client.Delete(endpoint, nil)
154162
}
155163

156164
// createBranch creates a new branch in the repository
157-
func createBranch(ctx context.Context, client *api.RESTClient, repo github.Repo, branch, sha string) error {
165+
func createBranch(ctx context.Context, client RESTClientInterface, repo github.Repo, branch, sha string) error {
158166
endpoint := fmt.Sprintf("repos/%s/%s/git/refs", repo.Owner, repo.Repo)
159167
payload := map[string]string{
160168
"ref": "refs/heads/" + branch,
@@ -168,7 +176,7 @@ func createBranch(ctx context.Context, client *api.RESTClient, repo github.Repo,
168176
}
169177

170178
// mergeBranch merges a branch into the base branch
171-
func mergeBranch(ctx context.Context, client *api.RESTClient, repo github.Repo, base, head string) error {
179+
func mergeBranch(ctx context.Context, client RESTClientInterface, repo github.Repo, base, head string) error {
172180
endpoint := fmt.Sprintf("repos/%s/%s/merges", repo.Owner, repo.Repo)
173181
payload := map[string]string{
174182
"base": base,
@@ -182,7 +190,7 @@ func mergeBranch(ctx context.Context, client *api.RESTClient, repo github.Repo,
182190
}
183191

184192
// updateRef updates a branch to point to the latest commit of another branch
185-
func updateRef(ctx context.Context, client *api.RESTClient, repo github.Repo, branch, sourceBranch string) error {
193+
func updateRef(ctx context.Context, client RESTClientInterface, repo github.Repo, branch, sourceBranch string) error {
186194
// Get the SHA of the source branch
187195
var ref struct {
188196
Object struct {
@@ -208,15 +216,25 @@ func updateRef(ctx context.Context, client *api.RESTClient, repo github.Repo, br
208216
return client.Patch(endpoint, body, nil)
209217
}
210218

211-
// createPullRequest creates a new pull request
212-
func createPullRequest(ctx context.Context, client *api.RESTClient, repo github.Repo, title, head, base, body string) error {
219+
func createPullRequest(ctx context.Context, client RESTClientInterface, repo github.Repo, title, head, base, body string, labels, assignees []string) error {
213220
endpoint := fmt.Sprintf("repos/%s/%s/pulls", repo.Owner, repo.Repo)
214-
payload := map[string]string{
215-
"title": title,
216-
"head": head,
217-
"base": base,
218-
"body": body,
221+
payload := map[string]interface{}{
222+
"title": title,
223+
"head": head,
224+
"base": base,
225+
"body": body,
219226
}
227+
228+
// Add labels if provided
229+
if len(labels) > 0 {
230+
payload["labels"] = labels
231+
}
232+
233+
// Add assignees if provided
234+
if len(assignees) > 0 {
235+
payload["assignees"] = assignees
236+
}
237+
220238
requestBody, err := encodePayload(payload)
221239
if err != nil {
222240
return fmt.Errorf("failed to encode payload: %w", err)

internal/cmd/combine_prs_test.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package cmd
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
"github.com/github/gh-combine/internal/github"
9+
)
10+
11+
func TestCreatePullRequest(t *testing.T) {
12+
client := &MockRESTClient{
13+
PostFunc: func(endpoint string, body interface{}, response interface{}) error {
14+
return nil
15+
},
16+
}
17+
repo := github.Repo{
18+
Owner: "test-owner",
19+
Repo: "test-repo",
20+
}
21+
title := "Test PR"
22+
head := "test-branch"
23+
base := "main"
24+
body := "This is a test PR."
25+
labels := []string{"bug", "enhancement"}
26+
assignees := []string{"octocat", "hubot"}
27+
28+
err := createPullRequest(context.Background(), client, repo, title, head, base, body, labels, assignees)
29+
assert.NoError(t, err)
30+
}

internal/cmd/mock_restclient.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package cmd
2+
3+
import (
4+
"context"
5+
"io"
6+
"net/http"
7+
)
8+
9+
type MockRESTClient struct {
10+
PostFunc func(endpoint string, body interface{}, response interface{}) error
11+
GetFunc func(endpoint string, response interface{}) error
12+
DeleteFunc func(endpoint string, response interface{}) error
13+
PatchFunc func(endpoint string, body io.Reader, response interface{}) error
14+
}
15+
16+
// Updated the Post method to match the RESTClientInterface signature
17+
func (m *MockRESTClient) Post(endpoint string, body io.Reader, response interface{}) error {
18+
if m.PostFunc != nil {
19+
return m.PostFunc(endpoint, body, response)
20+
}
21+
return nil
22+
}
23+
24+
func (m *MockRESTClient) Get(endpoint string, response interface{}) error {
25+
if m.GetFunc != nil {
26+
return m.GetFunc(endpoint, response)
27+
}
28+
return nil
29+
}
30+
31+
func (m *MockRESTClient) Delete(endpoint string, response interface{}) error {
32+
if m.DeleteFunc != nil {
33+
return m.DeleteFunc(endpoint, response)
34+
}
35+
return nil
36+
}
37+
38+
// Updated the Patch method to match the RESTClientInterface signature
39+
func (m *MockRESTClient) Patch(endpoint string, body io.Reader, response interface{}) error {
40+
if m.PatchFunc != nil {
41+
return m.PatchFunc(endpoint, body, response)
42+
}
43+
return nil
44+
}
45+
46+
func (m *MockRESTClient) RequestWithContext(ctx context.Context, method string, path string, body io.Reader) (*http.Response, error) {
47+
return nil, nil
48+
}
49+
50+
func (m *MockRESTClient) Request(method string, path string, body io.Reader) (*http.Response, error) {
51+
return nil, nil
52+
}
53+
54+
func (m *MockRESTClient) DoWithContext(ctx context.Context, method string, path string, body io.Reader, response interface{}) error {
55+
return nil
56+
}
57+
58+
func (m *MockRESTClient) Do(method string, path string, body io.Reader, response interface{}) error {
59+
return nil
60+
}
61+
62+
func (m *MockRESTClient) Put(path string, body io.Reader, resp interface{}) error {
63+
return nil
64+
}

internal/cmd/root.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,8 +289,13 @@ func processRepository(ctx context.Context, client *api.RESTClient, graphQlClien
289289

290290
Logger.Debug("Matched PRs", "repo", repo, "count", len(matchedPRs))
291291

292+
// Wrap the *api.RESTClient to implement RESTClientInterface
293+
restClientWrapper := struct {
294+
RESTClientInterface
295+
}{client}
296+
292297
// Combine the PRs
293-
err = CombinePRs(ctx, graphQlClient, client, repo, matchedPRs)
298+
err = CombinePRs(ctx, graphQlClient, restClientWrapper, repo, matchedPRs)
294299
if err != nil {
295300
return fmt.Errorf("failed to combine PRs: %w", err)
296301
}

vendor/github.com/davecgh/go-spew/LICENSE

Lines changed: 15 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)