Skip to content

Commit d4ffb91

Browse files
committed
feat(ksp): add support for context receivers and enhance function type handling
- Implemented context receivers handling in KSTypeNameCross for function types. - Added context receiver extraction logic for complex types. - Updated test project to include context receiver tests. - Enabled Kotlin compiler options for context parameters.
1 parent d3768ed commit d4ffb91

File tree

4 files changed

+132
-33
lines changed

4 files changed

+132
-33
lines changed

codegentle-kotlin-ksp/src/main/kotlin/love/forte/codegentle/kotlin/ksp/KSTypeNameCross.kt

Lines changed: 49 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,35 @@ private fun KSType.toFunctionTypeName(isSuspend: Boolean): KotlinLambdaTypeName
137137
// For function types like (A, B) -> C, the arguments are [A, B, C] where C is the return type
138138
// For extension function types like T.() -> R, the arguments are [T, R] where T is the receiver
139139
// and R is the return type. Extension function types are marked with @kotlin.ExtensionFunctionType
140+
// For context function types like context(C1, C2) T.(P1) -> R, the arguments are [C1, C2, T, P1, R]
141+
// where C1, C2 are context receivers, T is the receiver, P1 is the parameter, and R is the return type
140142
val typeArgs = arguments.map { arg ->
141143
arg.toTypeRef()
142144
}
143145

144-
// Check if this is an extension function type (has a receiver)
146+
// check context receivers(via @ContextFunctionTypeParams)
147+
// TODO see https://github.com/google/ksp/issues/2702
148+
val contextReceiverCount = annotations.firstNotNullOfOrNull { annotation ->
149+
val annotationType = annotation.annotationType
150+
val annotationTypeValidated = annotation.annotationType.validate()
151+
152+
val isContextAnnotation = annotation.shortName.asString() == "ContextFunctionTypeParams" &&
153+
if (annotationTypeValidated) {
154+
annotationType.resolve().declaration.qualifiedName?.asString() == "kotlin.ContextFunctionTypeParams"
155+
} else {
156+
// handle ERROR TYPE situation
157+
annotationType.toString() == "<ERROR TYPE: kotlin.ContextFunctionTypeParams>"
158+
}
159+
160+
if (isContextAnnotation) {
161+
// annotation's first argument is the number of context receivers
162+
annotation.arguments.firstOrNull()?.value as? Int
163+
} else {
164+
null
165+
}
166+
} ?: 0
167+
168+
// check if this is an extension function type (has a receiver)
145169
val isExtensionFunctionType = annotations.any { annotation ->
146170
// annotation: @ExtensionFunctionType
147171
// annotation.annotationType: <ERROR TYPE: kotlin.ExtensionFunctionType>
@@ -156,43 +180,37 @@ private fun KSType.toFunctionTypeName(isSuspend: Boolean): KotlinLambdaTypeName
156180
if (annotationTypeValidated) {
157181
annotationType.resolve().declaration.qualifiedName?.asString() == "kotlin.ExtensionFunctionType"
158182
} else {
159-
// TODO how to do?
183+
// ERROR TYPE
160184
annotationType.toString() == "<ERROR TYPE: kotlin.ExtensionFunctionType>"
161185
}
162186
}
163187

164188
if (typeArgs.isNotEmpty()) {
165-
if (isExtensionFunctionType) {
166-
// For extension function types like T.(P1, P2) -> R:
167-
// - First argument is the receiver type (T)
168-
// - Middle arguments are parameter types (P1, P2)
169-
// - Last argument is the return type (R)
170-
val receiverType = typeArgs.first()
171-
receiver(receiverType)
172-
173-
// The last argument is the return type
174-
val returnType = typeArgs.last()
175-
returns(returnType)
176-
177-
// Middle arguments (if any) are parameter types
178-
for (index in 1 until typeArgs.lastIndex) {
179-
val paramType = typeArgs[index]
180-
addParameter(KotlinValueParameterSpec.builder("", paramType).build())
181-
}
182-
} else {
183-
// For regular function types like (A, B) -> C:
184-
// - All arguments except the last are parameter types
185-
// - Last argument is the return type
186-
val returnType = typeArgs.last()
187-
returns(returnType)
188-
189-
// All arguments except the last are parameter types
190-
for ((index, paramType) in typeArgs.withIndex()) {
191-
if (index != typeArgs.lastIndex) {
192-
addParameter(KotlinValueParameterSpec.builder("", paramType).build())
193-
}
189+
var currentIndex = 0
190+
191+
// extract context receivers (if any)
192+
if (contextReceiverCount > 0) {
193+
repeat(contextReceiverCount) {
194+
addContextReceiver(typeArgs[currentIndex])
195+
currentIndex++
194196
}
195197
}
198+
199+
// extract receiver (if any)
200+
if (isExtensionFunctionType) {
201+
receiver(typeArgs[currentIndex])
202+
currentIndex++
203+
}
204+
205+
// last is return type
206+
val returnType = typeArgs.last()
207+
returns(returnType)
208+
209+
// middle is parameter type
210+
for (index in currentIndex until typeArgs.lastIndex) {
211+
val paramType = typeArgs[index]
212+
addParameter(KotlinValueParameterSpec.builder("", paramType).build())
213+
}
196214
}
197215
}
198216
}

tests/test-ksp-receiver-and-contexts/proc/src/main/kotlin/test/ksp/GenerateBackupProcessor.kt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,16 @@ class GenerateBackupProcessor(
4848

4949
try {
5050
// 使用 codegentle-kotlin-ksp 将 KSP 函数转换为 KotlinFunctionSpec
51+
println("original function: $function")
52+
function.parameters.forEach {
53+
println("\toriginal function parameter: $it")
54+
println("\toriginal function parameter.name: ${it.name?.asString()}")
55+
println("\toriginal function parameter.type: ${it.type}")
56+
println("\toriginal function parameter.type.annotations: ${it.type.annotations.toList()}")
57+
println("\toriginal function parameter.type.annotations: ${it.type.annotations.map { a -> a.annotationType }.toList()}")
58+
println("\toriginal function parameter.type.annotations.resolve(): ${it.type.resolve()}")
59+
println("\toriginal function parameter.type.annotations.resolve().annotations: ${it.type.resolve().annotations.toList()}")
60+
}
5161
val originalSpec = function.toKotlinFunctionSpec()
5262

5363
// 创建备份函数 - 使用新名字重建

tests/test-ksp-receiver-and-contexts/proj/build.gradle.kts

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,21 @@ plugins {
55

66
dependencies {
77
// 依赖处理器模块 - 注解定义
8-
implementation(project(":tests:test-ksp:proc"))
8+
implementation(project(":tests:test-ksp-receiver-and-contexts:proc"))
99

1010
// KSP 处理器
11-
ksp(project(":tests:test-ksp:proc"))
11+
ksp(project(":tests:test-ksp-receiver-and-contexts:proc"))
1212

1313
testImplementation(kotlin("test"))
1414
}
1515

1616
kotlin {
1717
jvmToolchain(11)
18+
19+
// 启用 context receivers 特性
20+
compilerOptions {
21+
freeCompilerArgs.add("-Xcontext-parameters")
22+
}
1823
}
1924

2025
tasks.withType<Test> {

tests/test-ksp-receiver-and-contexts/proj/src/main/kotlin/test/app/TestFunctions.kt

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,69 @@ fun multiParamReceiver(block: String.(Int, Boolean) -> Unit) {
4949
suspend fun suspendWithReceiver(block: suspend String.() -> Unit) {
5050
"Suspend".block()
5151
}
52+
53+
/**
54+
* 测试带有单个 context receiver 的函数类型。
55+
* 参数类型:context(String) () -> Unit
56+
*/
57+
@GenerateBackup
58+
fun runWithContext(block: context(String) () -> Unit) {
59+
context("Context") {
60+
block()
61+
}
62+
}
63+
64+
/**
65+
* 测试带有 context receiver 和扩展函数类型的组合。
66+
* 参数类型:context(String) Int.() -> Unit
67+
*/
68+
@GenerateBackup
69+
fun runWithContextAndReceiver(block: context(String) Int.() -> Unit) {
70+
with("Context") {
71+
42.block()
72+
}
73+
}
74+
75+
/**
76+
* 测试带有多个 context receivers 的函数类型。
77+
* 参数类型:context(String, Int) () -> Unit
78+
*/
79+
@GenerateBackup
80+
fun runWithMultipleContexts(block: context(String, Int) () -> Unit) {
81+
context("Context", 42) {
82+
block()
83+
}
84+
}
85+
86+
/**
87+
* 测试带有 context receivers、扩展 receiver 和参数的复杂函数类型。
88+
* 参数类型:context(String, Int) StringBuilder.(Boolean) -> Unit
89+
*/
90+
@GenerateBackup
91+
fun complexContextFunction(block: context(String, Int) StringBuilder.(Boolean) -> Unit) {
92+
context("Context", 42) {
93+
StringBuilder().block(true)
94+
}
95+
}
96+
97+
/**
98+
* 测试带有 context receiver 的 suspend 函数类型。
99+
* 参数类型:context(String) suspend () -> Unit
100+
*/
101+
@GenerateBackup
102+
suspend fun suspendWithContext(block: suspend context(String) () -> Unit) {
103+
context("Context") {
104+
block()
105+
}
106+
}
107+
108+
/**
109+
* 测试带有 context receiver 的 suspend 函数类型。
110+
* 参数类型:context(String) suspend () -> Unit
111+
*/
112+
@GenerateBackup
113+
suspend fun suspendWithContextAndReceiver(block: suspend context(String) StringBuilder.() -> Unit) {
114+
context("Context") {
115+
StringBuilder().block()
116+
}
117+
}

0 commit comments

Comments
 (0)