Skip to content

Commit

Permalink
Merge branch 'release/2024.2' into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
PenghaiZhang committed Dec 18, 2024
2 parents 644b0f3 + c007046 commit 7b98cc4
Show file tree
Hide file tree
Showing 22 changed files with 538 additions and 120 deletions.
5 changes: 5 additions & 0 deletions Source/Plugins/Core/com.equella.core/plugin-jpf.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6227,5 +6227,10 @@
<parameter id="platform" value="ENTRA_ID" />
<parameter id="order" value="1100" />
</extension>
<extension plugin-id="com.tle.core.usermanagement" point-id="oidcUserDir" id="oktaUserDir">
<parameter id="bean" value="bean:com.tle.core.usermanagement.OktaUserDirectory" />
<parameter id="platform" value="OKTA" />
<parameter id="order" value="1200" />
</extension>
<!-- endregion -->
</plugin>
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,23 @@

package com.tle.core.oauthclient

import java.time.Instant
import java.util.concurrent._
import cats.effect.IO
import sttp.client._
import sttp.client.circe._
import com.tle.common.institution.CurrentInstitution
import com.tle.web.oauth.OAuthWebConstants
import fs2.Stream
import io.circe.generic.semiauto._
import io.circe.{Decoder, Encoder}
import com.tle.core.httpclient._
import com.tle.core.oauthclient.OAuthClientService.{replicatedCache, responseToState}
import com.tle.core.oauthclient.OAuthTokenCacheHelper.{buildCacheKey, cacheId, requestToken}
import com.tle.legacy.LegacyGuice
import sttp.model.Header
import sttp.model.StatusCode
import com.tle.web.oauth.OAuthWebConstants
import fs2.Stream
import io.circe.generic.semiauto._
import io.circe.{Decoder, Encoder}
import sttp.client._
import sttp.client.circe._
import sttp.model.{Header, StatusCode}

import java.net.URI
import java.time.Instant
import java.util.concurrent._

object OAuthTokenType extends Enumeration {
val Bearer, EquellaApi = Value
Expand All @@ -48,25 +49,55 @@ object OAuthTokenType extends Enumeration {

}

/** Represent a POST request to obtain an OAuth2 Access Token.
/** Structure for the bare minimum of data required in one of the OAuth2 Client Authentication
* methods listed below:
*
* @param authTokenUrl
* The URL used to obtain an access token from the selected Identity Provider
* @param clientId
* Client ID used to get an Access Token to be used in API calls
* @param clientSecret
* Client Secret used with `clientId` to get an Access Token
* @param data
* Any additional data required in the request (e.g. 'audience' for Auth0)
* - Client Secret
* - Mutual TLS
* - Private Key JWT
*
* Currently, this structure provides the support for the Client Secret and Private Key JWT, but it
* can be extended to support Mutual TLS in the future if needed.
*
* Reference: https://oauth.net/2/client-authentication/
*/
sealed trait TokenRequest {

/** The URL used to obtain an access token from the selected Identity Provider
*/
def authTokenUrl: String

/** Client ID used to get an Access Token to be used in API calls
*/
def clientId: String

/** Any additional data required in the request (e.g. 'audience' for Auth0)
*/
def data: Option[Map[String, String]]

/** Build a unique key to identity the token request.
*/
final def key: String = clientId + authTokenUrl
}

/** Data structure for requesting an OAuth2 Access Token using the Client Secret method.
*/
case class TokenRequest(
final case class ClientSecretTokenRequest(
authTokenUrl: String,
clientId: String,
clientSecret: String,
data: Option[Map[String, String]] = None
) {
def key: String = clientId + authTokenUrl
}
) extends TokenRequest

/** Data structure for requesting an OAuth2 Access Token using the Private Key JWT method.
*/
final case class AssertionTokenRequest(
authTokenUrl: String,
clientId: String,
assertion: String,
assertionType: URI,
data: Option[Map[String, String]] = None
) extends TokenRequest

case class OAuthTokenState(
token: String,
Expand Down Expand Up @@ -108,16 +139,31 @@ object OAuthTokenCacheHelper {
val body = token.data
.getOrElse(Map.empty)
.toSeq :+ (OAuthWebConstants.PARAM_GRANT_TYPE -> OAuthWebConstants.GRANT_TYPE_CREDENTIALS)
val postRequest = basicRequest.auth
.basic(token.clientId, token.clientSecret)
.body(body: _*)
.response(asJsonAlways[OAuthTokenResponse])
.post(uri"${token.authTokenUrl}")

val postRequest = token match {
case req: ClientSecretTokenRequest =>
basicRequest.auth
.basic(req.clientId, req.clientSecret)
.body(body: _*)
case req: AssertionTokenRequest =>
val assertionParams = Seq(
OAuthWebConstants.PARAM_CLIENT_ASSERTION_TYPE -> req.assertionType.toString,
OAuthWebConstants.PARAM_CLIENT_ASSERTION -> req.assertion
)

val fullBody = body :++ assertionParams
basicRequest.body(fullBody: _*)
}

lazy val newOAuthTokenState =
sttpBackend
.flatMap(implicit backend => postRequest.send())
.map(r => r.body.fold(de => throw de.error, responseToState))
.flatMap(implicit backend =>
postRequest
.response(asJson[OAuthTokenResponse])
.post(uri"${token.authTokenUrl}")
.send()
)
.map(r => r.body.fold(de => throw de, responseToState))
.unsafeRunSync()

// Save the token in both cache and DB.
Expand Down Expand Up @@ -179,7 +225,7 @@ object OAuthClientService {
clientSecret: String,
request: Request[T, Stream[IO, Byte]]
): Response[T] = {
val tokenRequest = TokenRequest(authTokenUrl, clientId, clientSecret)
val tokenRequest = ClientSecretTokenRequest(authTokenUrl, clientId, clientSecret)
val token = tokenForClient(tokenRequest)
val res = requestWithToken(request, token.token, token.tokenType)
if (res.code == StatusCode.Unauthorized) removeToken(tokenRequest)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import cats.implicits._
import com.tle.common.Pair
import com.tle.common.usermanagement.user.valuebean.UserBean
import com.tle.core.oauthclient.{OAuthClientService, OAuthTokenState, TokenRequest}
import com.tle.integration.oidc.idp.GenericIdentityProviderDetails
import com.tle.plugins.ump.UserDirectory
import io.circe.Decoder
import org.slf4j.LoggerFactory
Expand Down Expand Up @@ -58,8 +57,6 @@ abstract class ApiUserDirectory extends OidcUserDirectory {

protected def toUserBean(user: USER): UserBean

override type IDP = GenericIdentityProviderDetails

override protected type AuthResult = OAuthTokenState

/** Use the provided Identity Provider details and user ID to build a full URL that points to the
Expand Down Expand Up @@ -124,9 +121,9 @@ abstract class ApiUserDirectory extends OidcUserDirectory {
query: String
): Pair[UserDirectory.ChainResult, util.Collection[UserBean]] = {
lazy val search: String => (IDP, OAuthTokenState) => Either[Throwable, USERS] =
query =>
q =>
(idp, tokenState) =>
requestWithToken[USERS](userListEndpoint(idp, query), tokenState, requestHeaders)
requestWithToken[USERS](userListEndpoint(idp, q), tokenState, requestHeaders)

val users = execute(search(query))
.leftMap(LOGGER.error(s"Failed to search users", _))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package com.tle.core.usermanagement

import com.tle.common.usermanagement.user.valuebean.{DefaultUserBean, UserBean}
import com.tle.core.guice.Bind
import com.tle.core.oauthclient.TokenRequest
import com.tle.core.oauthclient.ClientSecretTokenRequest
import com.tle.integration.oidc.idp.{GenericIdentityProviderDetails, IdentityProviderPlatform}
import io.circe.Decoder
import io.circe.generic.semiauto.deriveDecoder
Expand Down Expand Up @@ -55,6 +55,8 @@ class Auth0UserDirectory extends ApiUserDirectory {

override val targetPlatform: IdentityProviderPlatform.Value = IdentityProviderPlatform.AUTH0

override type IDP = GenericIdentityProviderDetails

override protected type USER = Auth0User

override protected implicit val userDecoder: Decoder[Auth0User] = deriveDecoder[Auth0User]
Expand Down Expand Up @@ -93,8 +95,8 @@ class Auth0UserDirectory extends ApiUserDirectory {
*
* Reference link: https://auth0.com/docs/secure/tokens/access-tokens/get-access-tokens.
*/
override protected def tokenRequest(idp: IDP): TokenRequest =
TokenRequest(
override protected def tokenRequest(idp: IDP): ClientSecretTokenRequest =
ClientSecretTokenRequest(
idp.commonDetails.tokenUrl.toString,
idp.apiClientId,
idp.apiClientSecret,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package com.tle.core.usermanagement

import com.tle.common.usermanagement.user.valuebean.{DefaultUserBean, UserBean}
import com.tle.core.guice.Bind
import com.tle.core.oauthclient.TokenRequest
import com.tle.core.oauthclient.ClientSecretTokenRequest
import com.tle.integration.oidc.OpenIDConnectParams
import com.tle.integration.oidc.idp.{GenericIdentityProviderDetails, IdentityProviderPlatform}
import io.circe.Decoder
Expand Down Expand Up @@ -53,6 +53,8 @@ class EntraIdUserDirectory extends ApiUserDirectory {
override protected val targetPlatform: IdentityProviderPlatform.Value =
IdentityProviderPlatform.ENTRA_ID

override type IDP = GenericIdentityProviderDetails

override protected type USER = EntraIdUser

override protected type USERS = EntraIdUserList
Expand Down Expand Up @@ -125,8 +127,8 @@ class EntraIdUserDirectory extends ApiUserDirectory {
* Reference link:
* https://learn.microsoft.com/en-us/entra/identity-platform/scopes-oidc#the-default-scope
*/
override protected def tokenRequest(idp: IDP): TokenRequest =
TokenRequest(
override protected def tokenRequest(idp: IDP): ClientSecretTokenRequest =
ClientSecretTokenRequest(
idp.commonDetails.tokenUrl.toString,
idp.apiClientId,
idp.apiClientSecret,
Expand Down
Loading

0 comments on commit 7b98cc4

Please sign in to comment.