Skip to content

Commit 728c1db

Browse files
committed
feat(auth): allow additional CA certs for OAuth2
Allow specifying additional TLS CA certificates for OAuth2 authenticator. This is useful when the cert for the OAuth2 services is not installed in the system and needs to be specified explicitly. Signed-off-by: Sergei Trofimov <[email protected]>
1 parent 8a3a730 commit 728c1db

File tree

3 files changed

+53
-22
lines changed

3 files changed

+53
-22
lines changed

auth/oauth2.go

+13
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"context"
77
"errors"
88
"fmt"
9+
"net/http"
910
"net/url"
1011
"strings"
1112
"time"
@@ -20,6 +21,7 @@ type Oauth2Authenticator struct {
2021
ClientSecret string
2122
Username string
2223
Password string
24+
CertPaths []string
2325

2426
Token *oauth2.Token
2527
}
@@ -31,6 +33,7 @@ func (o *Oauth2Authenticator) Configure(cfg map[string]interface{}) error {
3133
ClientSecret string `mapstructure:"client_secret"`
3234
Username string `mapstructure:"username"`
3335
Password string `mapstructure:"password"`
36+
CertPaths []string `mapstructure:"ca_cert"`
3437
Rest map[string]interface{} `mapstructure:",remain"`
3538
}{}
3639

@@ -43,6 +46,7 @@ func (o *Oauth2Authenticator) Configure(cfg map[string]interface{}) error {
4346
o.TokenURL = decoded.TokenURL
4447
o.Username = decoded.Username
4548
o.Password = decoded.Password
49+
o.CertPaths = decoded.CertPaths
4650

4751
if err := o.validate(); err != nil {
4852
return err
@@ -90,6 +94,15 @@ func (o *Oauth2Authenticator) obtainToken() (*oauth2.Token, error) {
9094
},
9195
}
9296

97+
if len(o.CertPaths) > 0 {
98+
transport, err := NewTLSTransport(o.CertPaths)
99+
if err != nil {
100+
return nil, err
101+
}
102+
client := &http.Client{Transport: transport}
103+
ctx = context.WithValue(ctx, oauth2.HTTPClient, client)
104+
}
105+
93106
return conf.PasswordCredentialsToken(ctx, o.Username, o.Password)
94107
}
95108

auth/tls.go

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Copyright 2024 Contributors to the Veraison project.
2+
// SPDX-License-Identifier: Apache-2.0
3+
package auth
4+
5+
import (
6+
"crypto/tls"
7+
"crypto/x509"
8+
"fmt"
9+
"net/http"
10+
"os"
11+
)
12+
13+
// NewTLSTransport returns a pointer to a new http.Transport with TLS config
14+
// initilaized with system certs as well as specified certPaths.
15+
func NewTLSTransport(certPaths []string) (*http.Transport, error) {
16+
certPool, err := x509.SystemCertPool()
17+
if err != nil {
18+
return nil, err
19+
}
20+
21+
for _, certPath := range certPaths {
22+
rawCert, err := os.ReadFile(certPath)
23+
if err != nil {
24+
return nil, fmt.Errorf("could not read cert: %w", err)
25+
}
26+
27+
if ok := certPool.AppendCertsFromPEM(rawCert); !ok {
28+
return nil, fmt.Errorf("invalid cert in %s", certPath)
29+
}
30+
}
31+
32+
return &http.Transport{
33+
TLSClientConfig: &tls.Config{
34+
RootCAs: certPool,
35+
MinVersion: tls.VersionTLS12,
36+
},
37+
}, nil
38+
}

common/client.go

+2-22
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,9 @@ package common
66
import (
77
"bytes"
88
"crypto/tls"
9-
"crypto/x509"
109
"fmt"
1110
"io"
1211
"net/http"
13-
"os"
1412
"time"
1513

1614
"github.com/veraison/apiclient/auth"
@@ -50,30 +48,12 @@ func NewInsecureTLSClient(a auth.IAuthenticator) *Client {
5048
// The client will use the provided IAuthenticator for requests, if it is not
5149
// nil.
5250
func NewTLSClient(a auth.IAuthenticator, certPaths []string) (*Client, error) {
53-
certPool, err := x509.SystemCertPool()
51+
transport, err := auth.NewTLSTransport(certPaths)
5452
if err != nil {
5553
return nil, err
5654
}
5755

58-
for _, certPath := range certPaths {
59-
rawCert, err := os.ReadFile(certPath)
60-
if err != nil {
61-
return nil, fmt.Errorf("could not read cert: %w", err)
62-
}
63-
64-
if ok := certPool.AppendCertsFromPEM(rawCert); !ok {
65-
return nil, fmt.Errorf("invalid cert in %s", certPath)
66-
}
67-
}
68-
69-
transport := http.Transport{
70-
TLSClientConfig: &tls.Config{
71-
RootCAs: certPool,
72-
MinVersion: tls.VersionTLS12,
73-
},
74-
}
75-
76-
return NewClientWithTransport(a, &transport), nil
56+
return NewClientWithTransport(a, transport), nil
7757
}
7858

7959
// NewClientWithTransport instantiates a new Client with the specified transport and a fixed

0 commit comments

Comments
 (0)