diff --git a/cli/oauth2.go b/cli/oauth2.go index 657fffa4..df0a8248 100644 --- a/cli/oauth2.go +++ b/cli/oauth2.go @@ -13,6 +13,7 @@ import ( "net/http" "net/url" "strings" + "sync" "github.com/RobotsAndPencils/go-saml" "github.com/coreos/go-oidc" @@ -63,11 +64,6 @@ type OAuth2CallbackInfo struct { ErrorDescription string } -type OAuth2Listener struct { - Socket net.Listener - callbackCh chan OAuth2CallbackInfo -} - func ParseCallbackRequest(r *http.Request) (OAuth2CallbackInfo, error) { info := OAuth2CallbackInfo{ Error: r.FormValue("error"), @@ -79,36 +75,66 @@ func ParseCallbackRequest(r *http.Request) (OAuth2CallbackInfo, error) { return info, nil } +// OAuth2Listener will listen for a single callback request from a web server and return the code if it matched, or an error otherwise. +type OAuth2Listener struct { + socket net.Listener + once sync.Once + callbackCh chan OAuth2CallbackInfo +} + func NewOAuth2Listener(socket net.Listener) OAuth2Listener { return OAuth2Listener{ - Socket: socket, + socket: socket, + // This channel is only ever closed if a successful request is received. + // If the caller closes the socket, then that channel will leak resources. + // + // This probably indicates a problem with the way this struct is constructed: + // The channel should be 'bound' to the lifetime of the socket. + // + // Still, it's a minor resource waste, so we don't care that much. + // + // We can't have a Close() function on this struct, because the caller could call Close() before a request is received, + // which would result on a send on a closed channel - which will cause a panic. + // + // The correct thing to do is probably modify this constructor to instead: + // * Accept a context + // * Return a channel + // * Close the channel when the context expires or it receives a request (whichever is first). + // + // Unfortunately, this is challenging to do while also ensuring that this struct adheres to the http.Handler interface. + // + // The correct solution probably means we change this function signature to + // + // func(context.Context, state string) (http.Handler, <-chan string) + // + // or the less re-usable/testable + // + // func(context.Context, socket net.Listener, state string) <-chan string + // + // The real problem we have with the current layout is that the struct can be put into invalid states, and the easiest way to avoid that + // is to simply not allow state manipulation at all using a closure. callbackCh: make(chan OAuth2CallbackInfo), } } -func (o OAuth2Listener) Close() error { - if o.callbackCh != nil { +func (o *OAuth2Listener) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // This can sometimes be called multiple times, depending on the browser. + // We will simply ignore any other requests and only serve the first. + o.once.Do(func() { + info, err := ParseCallbackRequest(r) + if err == nil { + // The only errors that might occur would be incorrectly formatted requests, which we will silently drop. + o.callbackCh <- info + } close(o.callbackCh) - } - if o.Socket != nil { - return o.Socket.Close() - } - return nil -} - -func (o OAuth2Listener) ServeHTTP(w http.ResponseWriter, r *http.Request) { - info, err := ParseCallbackRequest(r) - if err == nil { - // The only errors that might occur would be incorrectly formatted requests, which we will silently drop. - o.callbackCh <- info - } + }) - // This is displayed to the end user in their browser. + // We still want to provide feedback to the end-user. fmt.Fprintln(w, "You may close this window now.") } -func (o OAuth2Listener) Listen() error { - err := http.Serve(o.Socket, o) +func (o *OAuth2Listener) Listen() error { + err := http.Serve(o.socket, o) if errors.Is(err, http.ErrServerClosed) { return nil } @@ -116,7 +142,7 @@ func (o OAuth2Listener) Listen() error { return err } -func (o OAuth2Listener) WaitForAuthorizationCode(ctx context.Context, state string) (string, error) { +func (o *OAuth2Listener) WaitForAuthorizationCode(ctx context.Context, state string) (string, error) { select { case info := <-o.callbackCh: if info.Error != "" { @@ -218,6 +244,7 @@ func (r RedirectionFlowHandler) HandlePendingSession(ctx context.Context, challe if err != nil { return nil, err } + defer sock.Close() _, port, err := net.SplitHostPort(sock.Addr().String()) if err != nil { @@ -232,7 +259,6 @@ func (r RedirectionFlowHandler) HandlePendingSession(ctx context.Context, challe ) listener := NewOAuth2Listener(sock) - defer listener.Close() // This error can be ignored. go listener.Listen() diff --git a/cli/oauth2_test.go b/cli/oauth2_test.go index c2e97af5..ece42843 100644 --- a/cli/oauth2_test.go +++ b/cli/oauth2_test.go @@ -49,7 +49,6 @@ func Test_OAuth2Listener_WaitForAuthorizationCodeWorksCorrectly(t *testing.T) { cancel() assert.Equal(t, expectedCode, code) - assert.NoError(t, listener.Close()) } func Test_OAuth2Listener_ZeroValueNeverPanics(t *testing.T) { @@ -57,7 +56,6 @@ func Test_OAuth2Listener_ZeroValueNeverPanics(t *testing.T) { deadline, _ := context.WithTimeout(context.Background(), 500*time.Millisecond) _, err := listener.WaitForAuthorizationCode(deadline, "") assert.ErrorIs(t, context.DeadlineExceeded, err) - assert.NoError(t, listener.Close()) } // This test is going to be flaky because processes may open ports outside of our control.