@@ -93,24 +93,45 @@ enum UnsafePointerKind {
93
93
case Mutable
94
94
}
95
95
96
+ func getTypeName( _ type: TypeSyntax ) throws -> TokenSyntax {
97
+ switch type. kind {
98
+ case . memberType:
99
+ let memberType = type. as ( MemberTypeSyntax . self) !
100
+ if !memberType. baseType. isSwiftCoreModule {
101
+ throw DiagnosticError ( " expected pointer type in Swift core module, got type \( type) with base type \( memberType. baseType) " , node: type)
102
+ }
103
+ return memberType. name
104
+ case . identifierType:
105
+ return type. as ( IdentifierTypeSyntax . self) !. name
106
+ default :
107
+ throw DiagnosticError ( " expected pointer type, got \( type) with kind \( type. kind) " , node: type)
108
+ }
109
+ }
110
+
111
+ func replaceTypeName( _ type: TypeSyntax , _ name: TokenSyntax ) -> TypeSyntax {
112
+ if let memberType = type. as ( MemberTypeSyntax . self) {
113
+ return TypeSyntax ( memberType. with ( \. name, name) )
114
+ }
115
+ let idType = type. as ( IdentifierTypeSyntax . self) !
116
+ return TypeSyntax ( idType. with ( \. name, name) )
117
+ }
118
+
96
119
func transformType( _ prev: TypeSyntax , _ variant: Variant , _ isSizedBy: Bool ) throws -> TypeSyntax {
97
120
if let optType = prev. as ( OptionalTypeSyntax . self) {
98
121
return TypeSyntax ( optType. with ( \. wrappedType, try transformType ( optType. wrappedType, variant, isSizedBy) ) )
99
122
}
100
123
if let impOptType = prev. as ( ImplicitlyUnwrappedOptionalTypeSyntax . self) {
101
124
return try transformType ( impOptType. wrappedType, variant, isSizedBy)
102
125
}
103
- guard let idType = prev. as ( IdentifierTypeSyntax . self) else {
104
- throw DiagnosticError ( " expected pointer type, got \( prev) with kind \( prev. kind) " , node: prev)
105
- }
106
- let text = idType. name. text
126
+ let name = try getTypeName ( prev)
127
+ let text = name. text
107
128
let kind : UnsafePointerKind = switch text {
108
129
case " UnsafePointer " : . Immutable
109
130
case " UnsafeMutablePointer " : . Mutable
110
131
case " UnsafeRawPointer " : . Immutable
111
132
case " UnsafeMutableRawPointer " : . Mutable
112
133
default : throw DiagnosticError ( " expected Unsafe[Mutable][Raw]Pointer type for type \( prev) " +
113
- " - first type token is ' \( text) ' " , node: idType . name)
134
+ " - first type token is ' \( text) ' " , node: name)
114
135
}
115
136
if isSizedBy {
116
137
let token : TokenSyntax = switch ( kind, variant. generateSpan) {
@@ -122,15 +143,15 @@ func transformType(_ prev: TypeSyntax, _ variant: Variant, _ isSizedBy: Bool) th
122
143
return TypeSyntax ( IdentifierTypeSyntax ( name: token) )
123
144
}
124
145
if text == " UnsafeRawPointer " || text == " UnsafeMutableRawPointer " {
125
- throw DiagnosticError ( " raw pointers only supported for SizedBy " , node: idType . name)
146
+ throw DiagnosticError ( " raw pointers only supported for SizedBy " , node: name)
126
147
}
127
148
let token : TokenSyntax = switch ( kind, variant. generateSpan) {
128
149
case ( . Immutable, true ) : " Span "
129
150
case ( . Mutable, true ) : " MutableSpan "
130
151
case ( . Immutable, false ) : " UnsafeBufferPointer "
131
152
case ( . Mutable, false ) : " UnsafeMutableBufferPointer "
132
153
}
133
- return TypeSyntax ( idType . with ( \ . name , token) )
154
+ return replaceTypeName ( prev , token)
134
155
}
135
156
136
157
struct Variant {
@@ -209,11 +230,6 @@ struct FunctionCallBuilder: BoundsCheckedThunkBuilder {
209
230
}
210
231
}
211
232
212
- func hasReturnType( _ signature: FunctionSignatureSyntax ) -> Bool {
213
- let returnType = signature. returnClause? . type. as ( IdentifierTypeSyntax . self) ? . name. text ?? " Void "
214
- return returnType != " Void "
215
- }
216
-
217
233
protocol PointerBoundsThunkBuilder : BoundsCheckedThunkBuilder {
218
234
var name : TokenSyntax { get }
219
235
var nullable : Bool { get }
@@ -274,8 +290,7 @@ struct CountedOrSizedPointerThunkBuilder: PointerBoundsThunkBuilder {
274
290
}
275
291
276
292
func castIntToTargetType( expr: ExprSyntax , type: TypeSyntax ) -> ExprSyntax {
277
- let idType = type. as ( IdentifierTypeSyntax . self) !
278
- if idType. name. text == " Int " {
293
+ if type. isSwiftInt {
279
294
return expr
280
295
}
281
296
return ExprSyntax ( " \( type) (exactly: \( expr) )! " )
@@ -290,15 +305,10 @@ struct CountedOrSizedPointerThunkBuilder: PointerBoundsThunkBuilder {
290
305
let call = try base. buildFunctionCall ( args, variant)
291
306
let ptrRef = unwrapIfNullable ( ExprSyntax ( DeclReferenceExprSyntax ( baseName: name) ) )
292
307
293
- let returnKw : String = if hasReturnType ( signature) {
294
- " return "
295
- } else {
296
- " "
297
- }
298
308
let funcName = isSizedBy ? " withUnsafeBytes " : " withUnsafeBufferPointer "
299
309
let unwrappedCall = ExprSyntax ( """
300
310
\( ptrRef) . \( raw: funcName) { \( unwrappedName) in
301
- \( raw : returnKw ) \( call)
311
+ return \( call)
302
312
}
303
313
""" )
304
314
return unwrappedCall
@@ -419,7 +429,7 @@ public struct PointerBoundsMacro: PeerMacro {
419
429
guard let intLiteral = expr. as ( IntegerLiteralExprSyntax . self) else {
420
430
throw DiagnosticError ( " expected integer literal, got ' \( expr) ' " , node: expr)
421
431
}
422
- guard let res = Int ( intLiteral. literal . text ) else {
432
+ guard let res = intLiteral. representedLiteralValue else {
423
433
throw DiagnosticError ( " expected integer literal, got ' \( expr) ' " , node: expr)
424
434
}
425
435
return res
@@ -429,10 +439,14 @@ public struct PointerBoundsMacro: PeerMacro {
429
439
guard let boolLiteral = expr. as ( BooleanLiteralExprSyntax . self) else {
430
440
throw DiagnosticError ( " expected boolean literal, got ' \( expr) ' " , node: expr)
431
441
}
432
- guard let res = Bool ( boolLiteral. literal. text) else {
442
+ switch boolLiteral. literal. tokenKind {
443
+ case . keyword( . true ) :
444
+ return true
445
+ case . keyword( . false ) :
446
+ return false
447
+ default :
433
448
throw DiagnosticError ( " expected bool literal, got ' \( expr) ' " , node: expr)
434
449
}
435
- return res
436
450
}
437
451
438
452
static func parseCountedByEnum( _ enumConstructorExpr: FunctionCallExprSyntax , _ signature: FunctionSignatureSyntax ) throws -> ParamInfo {
@@ -590,12 +604,8 @@ public struct PointerBoundsMacro: PeerMacro {
590
604
CodeBlockItemSyntax ( leadingTrivia: " \n " , item: e)
591
605
}
592
606
}
593
- let call = if hasReturnType ( funcDecl. signature) {
594
- CodeBlockItemSyntax ( item: CodeBlockItemSyntax . Item ( ReturnStmtSyntax ( returnKeyword: . keyword( . return, trailingTrivia: " " ) ,
607
+ let call = CodeBlockItemSyntax ( item: CodeBlockItemSyntax . Item ( ReturnStmtSyntax ( returnKeyword: . keyword( . return, trailingTrivia: " " ) ,
595
608
expression: try builder. buildFunctionCall ( [ : ] , variant) ) ) )
596
- } else {
597
- CodeBlockItemSyntax ( item: CodeBlockItemSyntax . Item ( try builder. buildFunctionCall ( [ : ] , variant) ) )
598
- }
599
609
let body = CodeBlockSyntax ( statements: CodeBlockItemListSyntax ( checks + [ call] ) )
600
610
let newFunc = funcDecl
601
611
. with ( \. signature, newSignature)
0 commit comments