Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for setting module specific task locals #84

Merged
merged 2 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
.DS_Store
/.build
/.index-build
/Packages
/*.xcodeproj
xcuserdata/
Expand Down
110 changes: 77 additions & 33 deletions Sources/SotoCodeGeneratorLib/AwsService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ struct AwsService {

// separate by non-alphanumeric character, then capitalize the first letter of each component
// and join back together
let serviceName = sdkId
let serviceName =
sdkId
.components(separatedBy: CharacterSet.alphanumerics.inverted)
.map { $0.prefix(1).capitalized + $0.dropFirst() }
.joined()
Expand Down Expand Up @@ -153,7 +154,7 @@ struct AwsService {
/// filter operations list
mutating func filterOperations(_ filter: [String]) {
self.operations = self.operations.filter { key, _ in
return filter.contains(key.shapeName.toSwiftVariableCase())
filter.contains(key.shapeName.toSwiftVariableCase())
}
}

Expand Down Expand Up @@ -237,9 +238,10 @@ struct AwsService {
hostPrefix: endpointTrait?.hostPrefix,
deprecated: deprecatedTrait?.message,
streaming: streaming ? "ByteBuffer" : nil,
documentationUrl: nil, // added to comment
documentationUrl: nil, // added to comment
endpointRequired: requireEndpointDiscovery.map { OperationContext.DiscoverableEndpoint(required: $0) },
initParameters: initParamContext
initParameters: initParamContext,
taskLocals: generateTaskLocals(operation: operation)
)
}

Expand Down Expand Up @@ -276,6 +278,27 @@ struct AwsService {
}
}

func generateTaskLocals(
operation: OperationShape
) -> TaskLocalParameters? {
guard let staticParamsTrait = operation.trait(type: StaticContextParamsTrait.self) else { return nil }
let name: String
let possibleParameters: [String]
switch self.serviceEndpointPrefix {
case "s3":
name = "S3Middleware.$executionContext"
possibleParameters = ["UseS3ExpressControlEndpoint"]
default:
return nil
}
let parameters = staticParamsTrait.value
.filter { possibleParameters.contains($0.key) }
.compactMap { param in
param.value.dictionary?["value"].map { TaskLocalParameters.Parameter(key: param.key.toSwiftLabelCase(), value: $0) }
}
return !parameters.isEmpty ? .init(taskLocalName: name, taskLocalParams: parameters) : nil
}

static func getTrait<T: StaticTrait>(from shape: SotoSmithy.Shape, trait: T.Type, id: ShapeId) throws -> T {
guard let trait = shape.trait(type: T.self) else {
throw Error(reason: "\(id) does not have a \(T.staticName) trait")
Expand Down Expand Up @@ -353,7 +376,8 @@ struct AwsService {
if self.outputHTMLComments {
docs = documentation?.split(separator: "\n") ?? []
} else {
docs = documentation?
docs =
documentation?
.tagStriped()
.replacingOccurrences(of: "\n +", with: " ", options: .regularExpression, range: nil)
.split(separator: "\n")
Expand All @@ -375,7 +399,8 @@ struct AwsService {
if let recommendation = shape.trait(type: RecommendedTrait.self)?.reason {
documentation += "\n\(recommendation)"
}
return documentation
return
documentation
.tagStriped()
.replacingOccurrences(of: "\n +", with: " ", options: .regularExpression, range: nil)
.split(separator: "\n")
Expand All @@ -385,7 +410,7 @@ struct AwsService {

/// process documentation string
func processDocs(_ documentation: String?) -> [String.SubSequence] {
return documentation?
documentation?
.tagStriped()
.replacingOccurrences(of: "\n +", with: " ", options: .regularExpression, range: nil)
.split(separator: "\n")
Expand All @@ -399,11 +424,11 @@ struct AwsService {
return "AWSEditHeadersMiddleware(.add(name: \"accept\", value: \"application/json\"))"
case "Glacier":
return """
AWSMiddlewareStack {
AWSEditHeadersMiddleware(.add(name: \"x-amz-glacier-version\", value: \"\(service.version ?? "2012-06-01")\"))
TreeHashMiddleware(header: \"x-amz-sha256-tree-hash\")
}
"""
AWSMiddlewareStack {
AWSEditHeadersMiddleware(.add(name: \"x-amz-glacier-version\", value: \"\(service.version ?? "2012-06-01")\"))
TreeHashMiddleware(header: \"x-amz-sha256-tree-hash\")
}
"""
case "S3":
return "S3Middleware()"
default:
Expand All @@ -420,7 +445,7 @@ struct AwsService {
}

func encodingName(_ name: String) -> String {
return "_\(name)Encoding"
"_\(name)Encoding"
}

/// return payload member of structure
Expand Down Expand Up @@ -521,9 +546,10 @@ struct AwsService {
/// The JSON decoder requires an array to exist, even if it is empty so we have to make
/// all arrays in output shapes optional
func removeRequiredTraitFromOutputCollections(_ model: Model) {
guard self.serviceProtocolTrait is AwsProtocolsAwsJson1_0Trait ||
self.serviceProtocolTrait is AwsProtocolsAwsJson1_1Trait ||
self.serviceProtocolTrait is AwsProtocolsRestJson1Trait else { return }
guard
self.serviceProtocolTrait is AwsProtocolsAwsJson1_0Trait || self.serviceProtocolTrait is AwsProtocolsAwsJson1_1Trait
|| self.serviceProtocolTrait is AwsProtocolsRestJson1Trait
else { return }

for shape in model.shapes {
guard shape.value.hasTrait(type: SotoOutputShapeTrait.self) else { continue }
Expand Down Expand Up @@ -556,8 +582,9 @@ struct AwsService {
}
// if output token is member of an optional struct add ? suffix
if let member = structure.members?[String(split[0])] {
let required = member.hasTrait(type: RequiredTrait.self) ||
(member.hasTrait(type: HttpPayloadTrait.self) && structure.hasTrait(type: SotoOutputShapeTrait.self))
let required =
member.hasTrait(type: RequiredTrait.self)
|| (member.hasTrait(type: HttpPayloadTrait.self) && structure.hasTrait(type: SotoOutputShapeTrait.self))
if !required, split.count > 1 {
split[0] += "?"
}
Expand Down Expand Up @@ -603,13 +630,17 @@ struct AwsService {
guard let service = $0.services[self.serviceEndpointPrefix] else { return }
guard let partitionEndpoint = service.partitionEndpoint else { return }
guard let endpoint = service.endpoints[partitionEndpoint] else {
self.logger.error("Partition endpoint \(partitionEndpoint) for service \(self.serviceEndpointPrefix) in \($0.partitionName) does not exist")
self.logger.error(
"Partition endpoint \(partitionEndpoint) for service \(self.serviceEndpointPrefix) in \($0.partitionName) does not exist"
)
return
}
guard let region = endpoint.credentialScope?.region else {
// services with SigV4 authentication require an endpoint
if self.service.trait(type: AwsAuthSigV4Trait.self) != nil {
self.logger.error("Partition endpoint \(partitionEndpoint) for service \(self.serviceEndpointPrefix) in \($0.partitionName) has no credential scope region")
self.logger.error(
"Partition endpoint \(partitionEndpoint) for service \(self.serviceEndpointPrefix) in \($0.partitionName) has no credential scope region"
)
}
return
}
Expand Down Expand Up @@ -637,17 +668,24 @@ struct AwsService {
.sorted()
.joined(separator: ", ")
// get dnsSuffix for this variant
guard let dnsSuffix = getDefaultValue(partition: partition, service: service, getValue: { defaults in
return defaults.variants?.first(where: { $0.tags == variant.tags })?.dnsSuffix
}) else {
guard
let dnsSuffix = getDefaultValue(
partition: partition,
service: service,
getValue: { defaults in
defaults.variants?.first(where: { $0.tags == variant.tags })?.dnsSuffix
}
)
else {
continue
}
if variantEndpoints[variantString] == nil {
variantEndpoints[variantString] = .init()
}
if let hostname = variant.hostname {
// get hostname and replace any variables (wrapped in {}) in hostname
let finalHostname = hostname
let finalHostname =
hostname
.replacingOccurrences(of: "{region}", with: endpoint.key)
.replacingOccurrences(of: "{dnsSuffix}", with: dnsSuffix)
.replacingOccurrences(of: "{service}", with: self.serviceEndpointPrefix)
Expand All @@ -659,7 +697,7 @@ struct AwsService {
}
// return variants with endpoints sorted by region name
return variantEndpoints.mapValues {
return .init(defaultEndpoint: $0.defaultEndpoint, endpoints: $0.endpoints.sorted { $0.region < $1.region })
.init(defaultEndpoint: $0.defaultEndpoint, endpoints: $0.endpoints.sorted { $0.region < $1.region })
}
}

Expand Down Expand Up @@ -695,12 +733,9 @@ struct AwsService {
}

func isMemberInBody(_ member: MemberShape, isOutputShape: Bool) -> Bool {
return !(member.hasTrait(type: HttpHeaderTrait.self) ||
member.hasTrait(type: HttpPrefixHeadersTrait.self) ||
(member.hasTrait(type: HttpQueryTrait.self) && !isOutputShape) ||
member.hasTrait(type: HttpQueryParamsTrait.self) ||
member.hasTrait(type: HttpLabelTrait.self) ||
member.hasTrait(type: HttpResponseCodeTrait.self))
!(member.hasTrait(type: HttpHeaderTrait.self) || member.hasTrait(type: HttpPrefixHeadersTrait.self)
|| (member.hasTrait(type: HttpQueryTrait.self) && !isOutputShape) || member.hasTrait(type: HttpQueryParamsTrait.self)
|| member.hasTrait(type: HttpLabelTrait.self) || member.hasTrait(type: HttpResponseCodeTrait.self))
}
}

Expand All @@ -711,11 +746,19 @@ extension AwsService {
let reason: String
}

struct TaskLocalParameters {
struct Parameter {
let key: String
let value: Any
}
let taskLocalName: String
let taskLocalParams: [Parameter]
}

struct OperationContext {
struct DiscoverableEndpoint {
let required: Bool
}

let comment: [String.SubSequence]
let funcName: String
let inputShape: String?
Expand All @@ -729,6 +772,7 @@ extension AwsService {
let documentationUrl: String?
let endpointRequired: DiscoverableEndpoint?
var initParameters: [OperationInitParamContext]
let taskLocals: TaskLocalParameters?
}

struct OperationInitParamContext {
Expand Down Expand Up @@ -955,6 +999,6 @@ extension AwsService {
case jmesAllPath(path: String, expected: String)
case error(String)
case errorStatus(Int)
case success(Int) // Success requires a dummy associated value, so a mustache context is created for the `MatcherContext`
case success(Int) // Success requires a dummy associated value, so a mustache context is created for the `MatcherContext`
}
}
Loading
Loading