Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 25 additions & 10 deletions common/lib/authentication/iam_authentication_plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,26 @@ import { IamAuthUtils, TokenInfo } from "../utils/iam_auth_utils";
import { ClientWrapper } from "../client_wrapper";
import { RegionUtils } from "../utils/region_utils";
import { CanReleaseResources } from "../can_release_resources";
import { RdsUrlType } from "../utils/rds_url_type";
import { RdsUtils } from "../utils/rds_utils";
import { GDBRegionUtils } from "../utils/gdb_region_utils";

export class IamAuthenticationPlugin extends AbstractConnectionPlugin implements CanReleaseResources {
private static readonly SUBSCRIBED_METHODS = new Set<string>(["connect", "forceConnect"]);
protected static readonly tokenCache = new Map<string, TokenInfo>();
private readonly telemetryFactory;
private readonly fetchTokenCounter;
private pluginService: PluginService;
private readonly pluginService: PluginService;
private readonly rdsUtils: RdsUtils = new RdsUtils();
protected regionUtils: RegionUtils;
protected readonly iamAuthUtils: IamAuthUtils;

constructor(pluginService: PluginService) {
constructor(pluginService: PluginService, iamAuthUtils: IamAuthUtils = new IamAuthUtils()) {
super();
this.pluginService = pluginService;
this.telemetryFactory = this.pluginService.getTelemetryFactory();
this.fetchTokenCounter = this.telemetryFactory.createCounter("iam.fetchTokenCount");
this.iamAuthUtils = iamAuthUtils;
}

getSubscribedMethods(): Set<string> {
Expand Down Expand Up @@ -74,14 +81,22 @@ export class IamAuthenticationPlugin extends AbstractConnectionPlugin implements
throw new AwsWrapperError(`${WrapperProperties.USER} is null or empty`);
}

const host = IamAuthUtils.getIamHost(props, hostInfo);
const region: string = RegionUtils.getRegion(props.get(WrapperProperties.IAM_REGION.name), host);
const port = IamAuthUtils.getIamPort(props, hostInfo, this.pluginService.getCurrentClient().defaultPort);
const host = this.iamAuthUtils.getIamHost(props, hostInfo);
const port = this.iamAuthUtils.getIamPort(props, hostInfo, this.pluginService.getCurrentClient().defaultPort);

const type: RdsUrlType = this.rdsUtils.identifyRdsType(host.host);
this.regionUtils = type == RdsUrlType.RDS_GLOBAL_WRITER_CLUSTER ? new GDBRegionUtils() : new RegionUtils();
const region: string | null = await this.regionUtils.getRegion(WrapperProperties.IAM_REGION.name, host, props);

if (!region) {
throw new AwsWrapperError(Messages.get("SamlAuthPlugin.unableToDetermineRegion", WrapperProperties.IAM_REGION.name));
}

const tokenExpirationSec = WrapperProperties.IAM_TOKEN_EXPIRATION.get(props);
if (tokenExpirationSec < 0) {
throw new AwsWrapperError(Messages.get("AuthenticationToken.tokenExpirationLessThanZero"));
}
const cacheKey: string = IamAuthUtils.getCacheKey(port, user, host, region);
const cacheKey: string = this.iamAuthUtils.getCacheKey(port, user, host.host, region);

const tokenInfo = IamAuthenticationPlugin.tokenCache.get(cacheKey);
const isCachedToken: boolean = tokenInfo !== undefined && !tokenInfo.isExpired();
Expand All @@ -91,8 +106,8 @@ export class IamAuthenticationPlugin extends AbstractConnectionPlugin implements
WrapperProperties.PASSWORD.set(props, tokenInfo.token);
} else {
const tokenExpiry: number = Date.now() + tokenExpirationSec * 1000;
const token = await IamAuthUtils.generateAuthenticationToken(
host,
const token = await this.iamAuthUtils.generateAuthenticationToken(
host.host,
port,
region,
user,
Expand All @@ -118,8 +133,8 @@ export class IamAuthenticationPlugin extends AbstractConnectionPlugin implements
// Try to generate a new token and try to connect again

const tokenExpiry: number = Date.now() + tokenExpirationSec * 1000;
const token = await IamAuthUtils.generateAuthenticationToken(
host,
const token = await this.iamAuthUtils.generateAuthenticationToken(
host.host,
port,
region,
user,
Expand Down
1 change: 1 addition & 0 deletions common/lib/plugin_manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import { TelemetryTraceLevel } from "./utils/telemetry/telemetry_trace_level";
import { ConnectionProvider } from "./connection_provider";
import { ConnectionPluginFactory } from "./plugin_factory";
import { ConfigurationProfile } from "./profile/configuration_profile";
import { BaseSamlAuthPlugin } from "./plugins/federated_auth/saml_auth_plugin";

type PluginFunc<T> = (plugin: ConnectionPlugin, targetFunc: () => Promise<T>) => Promise<T>;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,14 @@ import { sleep } from "../../utils/utils";
import { CustomEndpointMonitor, CustomEndpointMonitorImpl } from "./custom_endpoint_monitor_impl";
import { SubscribedMethodHelper } from "../../utils/subscribed_method_helper";
import { CanReleaseResources } from "../../can_release_resources";
import { RdsUrlType } from "../../utils/rds_url_type";
import { GDBRegionUtils } from "../../utils/gdb_region_utils";

export class CustomEndpointPlugin extends AbstractConnectionPlugin implements CanReleaseResources {
private static readonly TELEMETRY_WAIT_FOR_INFO_COUNTER = "customEndpoint.waitForInfo.counter";
private static SUBSCRIBED_METHODS: Set<string> = new Set<string>(SubscribedMethodHelper.NETWORK_BOUND_METHODS);
private static readonly CACHE_CLEANUP_NANOS = BigInt(60_000_000_000);
private static readonly regionUtils: RegionUtils = new RegionUtils();

private static readonly rdsUtils = new RdsUtils();
protected static readonly monitors: SlidingExpirationCache<string, CustomEndpointMonitor> = new SlidingExpirationCache(
Expand Down Expand Up @@ -106,7 +109,7 @@ export class CustomEndpointPlugin extends AbstractConnectionPlugin implements Ca
throw new AwsWrapperError(Messages.get("CustomEndpointPlugin.errorParsingEndpointIdentifier", this.customEndpointHostInfo.host));
}

this.region = RegionUtils.getRegion(props.get(WrapperProperties.CUSTOM_ENDPOINT_REGION.name), this.customEndpointHostInfo.host);
this.region = await CustomEndpointPlugin.regionUtils.getRegion(WrapperProperties.CUSTOM_ENDPOINT_REGION.name, this.customEndpointHostInfo, props);
if (!this.region) {
throw new AwsWrapperError(Messages.get("CustomEndpointPlugin.unableToDetermineRegion", WrapperProperties.CUSTOM_ENDPOINT_REGION.name));
}
Expand Down
115 changes: 5 additions & 110 deletions common/lib/plugins/federated_auth/federated_auth_plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,118 +14,13 @@
limitations under the License.
*/

import { AbstractConnectionPlugin } from "../../abstract_connection_plugin";
import { PluginService } from "../../plugin_service";
import { RdsUtils } from "../../utils/rds_utils";
import { HostInfo } from "../../host_info";
import { IamAuthUtils, TokenInfo } from "../../utils/iam_auth_utils";
import { WrapperProperties } from "../../wrapper_property";
import { logger } from "../../../logutils";
import { AwsWrapperError } from "../../utils/errors";
import { Messages } from "../../utils/messages";
import { CredentialsProviderFactory } from "./credentials_provider_factory";
import { SamlUtils } from "../../utils/saml_utils";
import { ClientWrapper } from "../../client_wrapper";
import { TelemetryCounter } from "../../utils/telemetry/telemetry_counter";
import { RegionUtils } from "../../utils/region_utils";
import { CanReleaseResources } from "../../can_release_resources";
import { BaseSamlAuthPlugin } from "./saml_auth_plugin";
import { IamAuthUtils } from "../../utils/iam_auth_utils";

export class FederatedAuthPlugin extends AbstractConnectionPlugin implements CanReleaseResources {
protected static readonly tokenCache = new Map<string, TokenInfo>();
protected rdsUtils: RdsUtils = new RdsUtils();
protected pluginService: PluginService;
private static readonly subscribedMethods = new Set<string>(["connect", "forceConnect"]);
private readonly credentialsProviderFactory: CredentialsProviderFactory;
private readonly fetchTokenCounter: TelemetryCounter;

public getSubscribedMethods(): Set<string> {
return FederatedAuthPlugin.subscribedMethods;
}

constructor(pluginService: PluginService, credentialsProviderFactory: CredentialsProviderFactory) {
super();
this.credentialsProviderFactory = credentialsProviderFactory;
this.pluginService = pluginService;
this.fetchTokenCounter = this.pluginService.getTelemetryFactory().createCounter("federatedAuth.fetchToken.count");
}

connect(
hostInfo: HostInfo,
props: Map<string, any>,
isInitialConnection: boolean,
connectFunc: () => Promise<ClientWrapper>
): Promise<ClientWrapper> {
return this.connectInternal(hostInfo, props, connectFunc);
}

forceConnect(
hostInfo: HostInfo,
props: Map<string, any>,
isInitialConnection: boolean,
forceConnectFunc: () => Promise<ClientWrapper>
): Promise<ClientWrapper> {
return this.connectInternal(hostInfo, props, forceConnectFunc);
}

async connectInternal(hostInfo: HostInfo, props: Map<string, any>, connectFunc: () => Promise<ClientWrapper>): Promise<ClientWrapper> {
SamlUtils.checkIdpCredentialsWithFallback(props);

const host = IamAuthUtils.getIamHost(props, hostInfo);
const port = IamAuthUtils.getIamPort(props, hostInfo, this.pluginService.getDialect().getDefaultPort());
const region: string = RegionUtils.getRegion(props.get(WrapperProperties.IAM_REGION.name), host);

const cacheKey = IamAuthUtils.getCacheKey(port, WrapperProperties.DB_USER.get(props), host, region);
const tokenInfo = FederatedAuthPlugin.tokenCache.get(cacheKey);

const isCachedToken: boolean = tokenInfo !== undefined && !tokenInfo.isExpired();

if (isCachedToken && tokenInfo) {
logger.debug(Messages.get("AuthenticationToken.useCachedToken", tokenInfo.token));
WrapperProperties.PASSWORD.set(props, tokenInfo.token);
} else {
await this.updateAuthenticationToken(hostInfo, props, region, cacheKey, host);
}
WrapperProperties.USER.set(props, WrapperProperties.DB_USER.get(props));
this.pluginService.updateConfigWithProperties(props);

try {
return await connectFunc();
} catch (e) {
if (!this.pluginService.isLoginError(e as Error) || !isCachedToken) {
throw e;
}
try {
await this.updateAuthenticationToken(hostInfo, props, region, cacheKey, host);
return await connectFunc();
} catch (e: any) {
throw new AwsWrapperError(Messages.get("SamlAuthPlugin.unhandledError", e.message));
}
}
}

public async updateAuthenticationToken(hostInfo: HostInfo, props: Map<string, any>, region: string, cacheKey: string, iamHost: string) {
const tokenExpirationSec = WrapperProperties.IAM_TOKEN_EXPIRATION.get(props);
if (tokenExpirationSec < 0) {
throw new AwsWrapperError(Messages.get("AuthenticationToken.tokenExpirationLessThanZero"));
}
const tokenExpiry: number = Date.now() + tokenExpirationSec * 1000;
const port = IamAuthUtils.getIamPort(props, hostInfo, this.pluginService.getDialect().getDefaultPort());
const token = await IamAuthUtils.generateAuthenticationToken(
iamHost,
port,
region,
WrapperProperties.DB_USER.get(props),
await this.credentialsProviderFactory.getAwsCredentialsProvider(hostInfo.host, region, props),
this.pluginService
);
this.fetchTokenCounter.inc();
logger.debug(Messages.get("AuthenticationToken.generatedNewToken", token));
WrapperProperties.PASSWORD.set(props, token);
FederatedAuthPlugin.tokenCache.set(cacheKey, new TokenInfo(token, tokenExpiry));
}

releaseResources(): Promise<void> {
FederatedAuthPlugin.tokenCache.clear();
return;
export class FederatedAuthPlugin extends BaseSamlAuthPlugin {
constructor(pluginService: PluginService, credentialsProviderFactory: CredentialsProviderFactory, iamAuthUtils: IamAuthUtils = new IamAuthUtils()) {
super(pluginService, credentialsProviderFactory, "federatedAuth.fetchToken.count", iamAuthUtils);
}
}
117 changes: 5 additions & 112 deletions common/lib/plugins/federated_auth/okta_auth_plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,120 +14,13 @@
limitations under the License.
*/

import { AbstractConnectionPlugin } from "../../abstract_connection_plugin";
import { HostInfo } from "../../host_info";
import { SamlUtils } from "../../utils/saml_utils";
import { IamAuthUtils, TokenInfo } from "../../utils/iam_auth_utils";
import { PluginService } from "../../plugin_service";
import { CredentialsProviderFactory } from "./credentials_provider_factory";
import { RdsUtils } from "../../utils/rds_utils";
import { WrapperProperties } from "../../wrapper_property";
import { logger } from "../../../logutils";
import { Messages } from "../../utils/messages";
import { AwsWrapperError } from "../../utils/errors";
import { ClientWrapper } from "../../client_wrapper";
import { TelemetryCounter } from "../../utils/telemetry/telemetry_counter";
import { RegionUtils } from "../../utils/region_utils";
import { CanReleaseResources } from "../../can_release_resources";
import { BaseSamlAuthPlugin } from "./saml_auth_plugin";
import { IamAuthUtils } from "../../utils/iam_auth_utils";

export class OktaAuthPlugin extends AbstractConnectionPlugin implements CanReleaseResources {
protected static readonly tokenCache = new Map<string, TokenInfo>();
private static readonly subscribedMethods = new Set<string>(["connect", "forceConnect"]);
protected pluginService: PluginService;
protected rdsUtils = new RdsUtils();
private readonly credentialsProviderFactory: CredentialsProviderFactory;
private readonly fetchTokenCounter: TelemetryCounter;

constructor(pluginService: PluginService, credentialsProviderFactory: CredentialsProviderFactory) {
super();
this.pluginService = pluginService;
this.credentialsProviderFactory = credentialsProviderFactory;
this.fetchTokenCounter = this.pluginService.getTelemetryFactory().createCounter("oktaAuth.fetchToken.count");
}

public getSubscribedMethods(): Set<string> {
return OktaAuthPlugin.subscribedMethods;
}

connect(
hostInfo: HostInfo,
props: Map<string, any>,
isInitialConnection: boolean,
connectFunc: () => Promise<ClientWrapper>
): Promise<ClientWrapper> {
return this.connectInternal(hostInfo, props, connectFunc);
}

forceConnect(
hostInfo: HostInfo,
props: Map<string, any>,
isInitialConnection: boolean,
connectFunc: () => Promise<ClientWrapper>
): Promise<ClientWrapper> {
return this.connectInternal(hostInfo, props, connectFunc);
}

async connectInternal(hostInfo: HostInfo, props: Map<string, any>, connectFunc: () => Promise<ClientWrapper>): Promise<ClientWrapper> {
SamlUtils.checkIdpCredentialsWithFallback(props);

const host = IamAuthUtils.getIamHost(props, hostInfo);
const port = IamAuthUtils.getIamPort(props, hostInfo, this.pluginService.getDialect().getDefaultPort());
const region = RegionUtils.getRegion(props.get(WrapperProperties.IAM_REGION.name), host);

const cacheKey = IamAuthUtils.getCacheKey(port, WrapperProperties.DB_USER.get(props), host, region);
const tokenInfo = OktaAuthPlugin.tokenCache.get(cacheKey);

const isCachedToken = tokenInfo !== undefined && !tokenInfo.isExpired();

if (isCachedToken) {
logger.debug(Messages.get("AuthenticationToken.useCachedToken", tokenInfo.token));
WrapperProperties.PASSWORD.set(props, tokenInfo.token);
} else {
await this.updateAuthenticationToken(hostInfo, props, region, cacheKey, host);
}
WrapperProperties.USER.set(props, WrapperProperties.DB_USER.get(props));
this.pluginService.updateConfigWithProperties(props);

try {
return await connectFunc();
} catch (e: any) {
if (!this.pluginService.isLoginError(e as Error) || !isCachedToken) {
logger.debug(Messages.get("Authentication.connectError", e.message));
throw e;
}
try {
await this.updateAuthenticationToken(hostInfo, props, region, cacheKey, host);
return await connectFunc();
} catch (e: any) {
throw new AwsWrapperError(Messages.get("SamlAuthPlugin.unhandledError", e.message));
}
}
}

public async updateAuthenticationToken(hostInfo: HostInfo, props: Map<string, any>, region: string, cacheKey: string, iamHost): Promise<void> {
const tokenExpirationSec = WrapperProperties.IAM_TOKEN_EXPIRATION.get(props);
if (tokenExpirationSec < 0) {
throw new AwsWrapperError(Messages.get("AuthenticationToken.tokenExpirationLessThanZero"));
}
const tokenExpiry = Date.now() + tokenExpirationSec * 1000;
const port = IamAuthUtils.getIamPort(props, hostInfo, this.pluginService.getDialect().getDefaultPort());
this.fetchTokenCounter.inc();
const token = await IamAuthUtils.generateAuthenticationToken(
iamHost,
port,
region,
WrapperProperties.DB_USER.get(props),
await this.credentialsProviderFactory.getAwsCredentialsProvider(hostInfo.host, region, props),
this.pluginService
);
logger.debug(Messages.get("AuthenticationToken.generatedNewToken", token));
WrapperProperties.PASSWORD.set(props, token);
this.pluginService.updateConfigWithProperties(props);
OktaAuthPlugin.tokenCache.set(cacheKey, new TokenInfo(token, tokenExpiry));
}

releaseResources(): Promise<void> {
OktaAuthPlugin.tokenCache.clear();
return;
export class OktaAuthPlugin extends BaseSamlAuthPlugin {
constructor(pluginService: PluginService, credentialsProviderFactory: CredentialsProviderFactory, iamAuthUtils: IamAuthUtils = new IamAuthUtils()) {
super(pluginService, credentialsProviderFactory, "oktaAuth.fetchToken.count", iamAuthUtils);
}
}
Loading
Loading