Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ project.ext.externalDependency = [
'avroCompiler': 'org.apache.avro:avro-compiler:1.11.4',
'awsGlueSchemaRegistrySerde': 'software.amazon.glue:schema-registry-serde:1.1.23',
'awsMskIamAuth': 'software.amazon.msk:aws-msk-iam-auth:2.3.2',
'awsSdk2Bom': 'software.amazon.awssdk:bom:2.23.6',
'awsS3': "software.amazon.awssdk:s3:$awsSdk2Version",
'awsSecretsManagerJdbc': 'com.amazonaws.secretsmanager:aws-secretsmanager-jdbc:1.0.15',
'awsPostgresIamAuth': 'software.amazon.jdbc:aws-advanced-jdbc-wrapper:2.5.4',
Expand Down
5 changes: 5 additions & 0 deletions datahub-graphql-core/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ dependencies {
implementation externalDependency.guava
implementation externalDependency.opentelemetryAnnotations

implementation platform(externalDependency.awsSdk2Bom)
implementation 'software.amazon.awssdk:regions'
implementation 'software.amazon.awssdk:sts'
implementation 'software.amazon.awssdk:s3'

implementation externalDependency.slf4jApi
implementation externalDependency.springContext
compileOnly externalDependency.lombok
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ private Constants() {}
public static final String TIMESERIES_SCHEMA_FILE = "timeseries.graphql";
public static final String LOGICAL_SCHEMA_FILE = "logical.graphql";
public static final String SETTINGS_SCHEMA_FILE = "settings.graphql";
public static final String FILES_SCHEMA_FILE = "files.graphql";

public static final String QUERY_SCHEMA_FILE = "query.graphql";
public static final String TEMPLATE_SCHEMA_FILE = "template.graphql";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
import com.linkedin.datahub.graphql.resolvers.entity.EntityPrivilegesResolver;
import com.linkedin.datahub.graphql.resolvers.entity.versioning.LinkAssetVersionResolver;
import com.linkedin.datahub.graphql.resolvers.entity.versioning.UnlinkAssetVersionResolver;
import com.linkedin.datahub.graphql.resolvers.files.GetPresignedUploadUrlResolver;
import com.linkedin.datahub.graphql.resolvers.form.BatchAssignFormResolver;
import com.linkedin.datahub.graphql.resolvers.form.BatchRemoveFormResolver;
import com.linkedin.datahub.graphql.resolvers.form.CreateDynamicFormAssignmentResolver;
Expand Down Expand Up @@ -315,6 +316,7 @@
import com.linkedin.datahub.graphql.types.test.TestType;
import com.linkedin.datahub.graphql.types.versioning.VersionSetType;
import com.linkedin.datahub.graphql.types.view.DataHubViewType;
import com.linkedin.datahub.graphql.util.S3Util;
import com.linkedin.entity.client.EntityClient;
import com.linkedin.entity.client.SystemEntityClient;
import com.linkedin.metadata.client.UsageStatsJavaClient;
Expand Down Expand Up @@ -489,6 +491,8 @@ public class GmsGraphQLEngine {
private final GraphQLConfiguration graphQLConfiguration;
private final MetricUtils metricUtils;

private final S3Util s3Util;

private final BusinessAttributeType businessAttributeType;

/** A list of GraphQL Plugins that extend the core engine */
Expand Down Expand Up @@ -621,6 +625,7 @@ public GmsGraphQLEngine(final GmsGraphQLEngineArgs args) {
this.dataHubPageModuleType = new PageModuleType(entityClient);
this.graphQLConfiguration = args.graphQLConfiguration;
this.metricUtils = args.metricUtils;
this.s3Util = args.s3Util;

this.businessAttributeType = new BusinessAttributeType(entityClient);
// Init Lists
Expand Down Expand Up @@ -845,7 +850,8 @@ public GraphQLEngine.Builder builder() {
.addSchema(fileBasedSchema(QUERY_SCHEMA_FILE))
.addSchema(fileBasedSchema(TEMPLATE_SCHEMA_FILE))
.addSchema(fileBasedSchema(MODULE_SCHEMA_FILE))
.addSchema(fileBasedSchema(SETTINGS_SCHEMA_FILE));
.addSchema(fileBasedSchema(SETTINGS_SCHEMA_FILE))
.addSchema(fileBasedSchema(FILES_SCHEMA_FILE));

for (GmsGraphQLPlugin plugin : this.graphQLPlugins) {
List<String> pluginSchemaFiles = plugin.getSchemaFiles();
Expand Down Expand Up @@ -1108,7 +1114,11 @@ private void configureQueryResolvers(final RuntimeWiring.Builder builder) {
new DocPropagationSettingsResolver(this.settingsService))
.dataFetcher(
"globalHomePageSettings",
new GlobalHomePageSettingsResolver(this.settingsService)));
new GlobalHomePageSettingsResolver(this.settingsService))
.dataFetcher(
"getPresignedUploadUrl",
new GetPresignedUploadUrlResolver(
this.s3Util, this.datahubConfiguration.getS3())));
}

private DataFetcher getEntitiesResolver() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import com.datahub.authorization.role.RoleService;
import com.linkedin.datahub.graphql.analytics.service.AnalyticsService;
import com.linkedin.datahub.graphql.featureflags.FeatureFlags;
import com.linkedin.datahub.graphql.util.S3Util;
import com.linkedin.entity.client.EntityClient;
import com.linkedin.entity.client.SystemEntityClient;
import com.linkedin.metadata.client.UsageStatsJavaClient;
Expand Down Expand Up @@ -97,6 +98,6 @@ public class GmsGraphQLEngineArgs {
PageModuleService pageModuleService;
boolean systemTelemetryEnabled;
MetricUtils metricUtils;

S3Util s3Util;
// any fork specific args should go below this line
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package com.linkedin.datahub.graphql.resolvers.files;

import static com.linkedin.datahub.graphql.resolvers.ResolverUtils.bindArgument;

import com.linkedin.common.urn.UrnUtils;
import com.linkedin.datahub.graphql.QueryContext;
import com.linkedin.datahub.graphql.concurrency.GraphQLConcurrencyUtils;
import com.linkedin.datahub.graphql.exception.AuthorizationException;
import com.linkedin.datahub.graphql.generated.GetPresignedUploadUrlInput;
import com.linkedin.datahub.graphql.generated.GetPresignedUploadUrlResponse;
import com.linkedin.datahub.graphql.generated.UploadDownloadScenario;
import com.linkedin.datahub.graphql.resolvers.mutate.DescriptionUtils;
import com.linkedin.datahub.graphql.util.S3Util;
import com.linkedin.metadata.config.S3Configuration;
import graphql.schema.DataFetcher;
import graphql.schema.DataFetchingEnvironment;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

@Slf4j
@Component
public class GetPresignedUploadUrlResolver
implements DataFetcher<CompletableFuture<GetPresignedUploadUrlResponse>> {

private final S3Util s3Util;
private final S3Configuration s3Configuration;

public GetPresignedUploadUrlResolver(S3Util s3Util, S3Configuration s3Configuration) {
this.s3Util = s3Util;
this.s3Configuration = s3Configuration;
}

@Override
public CompletableFuture<GetPresignedUploadUrlResponse> get(DataFetchingEnvironment environment)
throws Exception {
if (s3Util == null) {
throw new IllegalArgumentException("S3Util isn't provided");
}

String bucketName = s3Configuration.getBucketName();

if (bucketName == null || bucketName.isEmpty()) {
throw new IllegalArgumentException("Bucket name isn't provided");
}

final GetPresignedUploadUrlInput input =
bindArgument(environment.getArgument("input"), GetPresignedUploadUrlInput.class);

final QueryContext context = environment.getContext();

validateInput(context, input);

String newFileId = generateNewFileId(input);
String s3Key = getS3Key(input, newFileId, bucketName);
String contentType = input.getContentType();

return GraphQLConcurrencyUtils.supplyAsync(
() -> {
String presignedUploadUrl =
s3Util.generatePresignedUploadUrl(
bucketName,
s3Key,
s3Configuration.getPresignedUploadUrlExpirationSeconds(),
contentType);

GetPresignedUploadUrlResponse result = new GetPresignedUploadUrlResponse();
result.setUrl(presignedUploadUrl);
result.setFileId(newFileId);
return result;
},
this.getClass().getSimpleName(),
"get");
}

private void validateInput(final QueryContext context, final GetPresignedUploadUrlInput input) {
UploadDownloadScenario scenario = input.getScenario();

if (scenario == UploadDownloadScenario.ASSET_DOCUMENTATION) {
validateInputForAssetDocumentationScenario(context, input);
}
}

private void validateInputForAssetDocumentationScenario(
final QueryContext context, final GetPresignedUploadUrlInput input) {
String assetUrn = input.getAssetUrn();

if (assetUrn == null) {
throw new IllegalArgumentException("assetUrn is required for ASSET_DOCUMENTATION scenario");
}

if (!DescriptionUtils.isAuthorizedToUpdateDescription(context, UrnUtils.getUrn(assetUrn))) {
throw new AuthorizationException("Unauthorized to edit documentation for asset: " + assetUrn);
}
}

private String generateNewFileId(final GetPresignedUploadUrlInput input) {
return String.format("%s-%s", UUID.randomUUID().toString(), input.getFileName());
}

private String getS3Key(
final GetPresignedUploadUrlInput input, final String fileId, final String bucketName) {
UploadDownloadScenario scenario = input.getScenario();

if (scenario == UploadDownloadScenario.ASSET_DOCUMENTATION) {
return String.format(
"%s/%s/%s",
s3Configuration.getBucketName(), s3Configuration.getAssetPathPrefix(), fileId);
} else {
throw new IllegalArgumentException("Unsupported upload scenario: " + scenario);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
package com.linkedin.datahub.graphql.util;

import com.linkedin.entity.client.EntityClient;
import java.time.Duration;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import lombok.extern.slf4j.Slf4j;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.*;
import software.amazon.awssdk.services.s3.presigner.S3Presigner;
import software.amazon.awssdk.services.s3.presigner.model.GetObjectPresignRequest;
import software.amazon.awssdk.services.s3.presigner.model.PresignedGetObjectRequest;
import software.amazon.awssdk.services.s3.presigner.model.PresignedPutObjectRequest;
import software.amazon.awssdk.services.s3.presigner.model.PutObjectPresignRequest;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider;

@Slf4j
public class S3Util {

private final S3Client s3Client;
private final EntityClient entityClient;

// Optional S3Presigner for testing purposes
@Nullable private final S3Presigner s3Presigner;

public S3Util(@Nonnull S3Client s3Client, @Nonnull EntityClient entityClient) {
this(s3Client, entityClient, null);
}

public S3Util(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's double check @anshbansal has reviewed this file since he was the original implementer

@Nonnull S3Client s3Client,
@Nonnull EntityClient entityClient,
@Nullable S3Presigner s3Presigner) {
this.s3Client = s3Client;
this.entityClient = entityClient;
this.s3Presigner = s3Presigner;
}

public S3Util(
@Nonnull EntityClient entityClient, @Nonnull StsClient stsClient, @Nonnull String roleArn) {
this(entityClient, stsClient, roleArn, null);
}

public S3Util(
@Nonnull EntityClient entityClient,
@Nonnull StsClient stsClient,
@Nonnull String roleArn,
@Nullable S3Presigner s3Presigner) {
this.entityClient = entityClient;
this.s3Presigner = s3Presigner;
this.s3Client = createS3Client(stsClient, roleArn);
}

/** Creates S3Client with StsAssumeRoleCredentialsProvider for automatic credential refresh. */
private static S3Client createS3Client(@Nonnull StsClient stsClient, @Nonnull String roleArn) {
try {
log.info("Creating S3Client for role: {}", roleArn);

StsAssumeRoleCredentialsProvider credentialsProvider =
StsAssumeRoleCredentialsProvider.builder()
.stsClient(stsClient)
.refreshRequest(r -> r.roleArn(roleArn).roleSessionName("s3-session"))
.asyncCredentialUpdateEnabled(true) // Enable background credential refresh
.build();

var clientBuilder = S3Client.builder().credentialsProvider(credentialsProvider);

// Configure endpoint URL if provided (for LocalStack or custom S3 endpoints)
String endpointUrl = System.getenv("AWS_ENDPOINT_URL");
if (endpointUrl != null && !endpointUrl.isEmpty()) {
clientBuilder.endpointOverride(java.net.URI.create(endpointUrl));
// Force path-style access for LocalStack compatibility
clientBuilder.forcePathStyle(true);
}

S3Client client = clientBuilder.build();
log.info("Successfully created S3Client for role: {}", roleArn);
return client;

} catch (Exception e) {
log.error("Failed to create S3 client: roleArn={}", roleArn, e);
throw new RuntimeException("Failed to create S3 clien: " + e.getMessage(), e);
}
}

private S3Presigner getPresigner() {
if (this.s3Presigner != null) {
return this.s3Presigner;
}

return S3Presigner.builder()
.credentialsProvider(s3Client.serviceClientConfiguration().credentialsProvider())
.region(s3Client.serviceClientConfiguration().region())
.build();
}

/**
* Generate a pre-signed URL for downloading an S3 object
*
* @param bucket The S3 bucket name
* @param key The S3 object key
* @param expirationSeconds The expiration time in seconds
* @return The pre-signed URL
*/
public String generatePresignedDownloadUrl(
@Nonnull String bucket, @Nonnull String key, int expirationSeconds) {
try {
// Create a pre-signer using the same configuration as the S3 client
try (S3Presigner presigner = getPresigner()) {

// Create the GetObjectRequest
GetObjectRequest getObjectRequest =
GetObjectRequest.builder().bucket(bucket).key(key).build();

// Create the presign request
GetObjectPresignRequest presignRequest =
GetObjectPresignRequest.builder()
.signatureDuration(Duration.ofSeconds(expirationSeconds))
.getObjectRequest(getObjectRequest)
.build();

// Generate the presigned URL
PresignedGetObjectRequest presignedRequest = presigner.presignGetObject(presignRequest);
return presignedRequest.url().toString();
}
} catch (Exception e) {
log.error("Failed to generate presigned URL for bucket: {}, key: {}", bucket, key, e);
throw new RuntimeException("Failed to generate presigned URL: " + e.getMessage(), e);
}
}

/**
* Generate a pre-signed URL for uploading an S3 object
*
* @param bucket The S3 bucket name
* @param key The S3 object key
* @param expirationSeconds The expiration time in seconds
* @param contentType The content type of the object to be uploaded (e.g., "image/jpeg",
* "application/pdf")
* @return The pre-signed URL
*/
public String generatePresignedUploadUrl(
@Nonnull String bucket,
@Nonnull String key,
int expirationSeconds,
@Nullable String contentType) {
try {
// Create a pre-signer using the same configuration as the S3 client
try (S3Presigner presigner = getPresigner()) {

// Create the PutObjectRequest
PutObjectRequest putObjectRequest =
PutObjectRequest.builder().bucket(bucket).contentType(contentType).key(key).build();

// Create the presign request
PutObjectPresignRequest presignRequest =
PutObjectPresignRequest.builder()
.signatureDuration(Duration.ofSeconds(expirationSeconds))
.putObjectRequest(putObjectRequest)
.build();

// Generate the presigned URL
PresignedPutObjectRequest presignedRequest = presigner.presignPutObject(presignRequest);
return presignedRequest.url().toString();
}
} catch (Exception e) {
log.error("Failed to generate presigned upload URL for bucket: {}, key: {}", bucket, key, e);
throw new RuntimeException("Failed to generate presigned upload URL: " + e.getMessage(), e);
}
}
}
Loading
Loading