diff --git a/main.go b/main.go index bb65edc..dfc69fe 100644 --- a/main.go +++ b/main.go @@ -96,13 +96,23 @@ func newCommentLoopChannel(ctx context.Context, apprv *approvalEnvironment, clie return channel } -func newGithubClient(ctx context.Context) *github.Client { +func newGithubClient(ctx context.Context) (*github.Client, error) { token := os.Getenv(envVarToken) ts := oauth2.StaticTokenSource( &oauth2.Token{AccessToken: token}, ) tc := oauth2.NewClient(ctx, ts) - return github.NewClient(tc) + + serverUrl, serverUrlPresent := os.LookupEnv("GITHUB_SERVER_URL") + apiUrl, apiUrlPresent := os.LookupEnv("GITHUB_API_URL") + + if serverUrlPresent { + if ! apiUrlPresent { + apiUrl = serverUrl + } + return github.NewEnterpriseClient(apiUrl, serverUrl, tc) + } + return github.NewClient(tc), nil } func validateInput() error { @@ -148,7 +158,11 @@ func main() { repoOwner := os.Getenv(envVarRepoOwner) ctx := context.Background() - client := newGithubClient(ctx) + client, err := newGithubClient(ctx) + if err != nil { + fmt.Printf("error connecting to server: %v\n", err) + os.Exit(1) + } approvers, err := retrieveApprovers(client, repoOwner) if err != nil {