Skip to content

Commit 1e403ec

Browse files
authored
[AutoDiff] Support custom derivatives for @_alwaysEmitIntoClient functions (swiftlang#78908)
Consider an `@_alwaysEmitIntoClient` function and a custom derivative defined for it. Previously, such a combination resulted different errors under different circumstances. Sometimes, there were linker errors due to missing derivative function symbol - these occurred when we tried to find the derivative in a module, while it should have been emitted into client's code (and it did not happen). Sometimes, there were SIL verification failures like this: ``` SIL verification failed: internal/private function cannot be serialized or serializable: !F->isAnySerialized() || embedded ``` Linkage and serialization options for the derivative were not handled properly, and, instead of PublicNonABI linkage, we had Private one which is unsupported for serialization - but we need to serialize `@_alwaysEmitIntoClient` functions so the client's code is able to see them. This patch resolves the issue and adds proper handling of custom derivatives of `@_alwaysEmitIntoClient` functions. Note that either both the function and its custom derivative or none of them should have `@_alwaysEmitIntoClient` attribute, mismatch in this attribute is not supported. The following cases are handled (assume that in each case client's code uses the derivative). 1. Both the function and its derivative are defined in a single file in one module. 2. Both the function and its derivative are defined in different files which are compiled to a single module. 3. The function is defined in one module, its derivative is defined in another module. 4. The function and the derivative are defined as members of a protocol extension in two separate modules - one for the function and one for the derivative. A struct conforming the protocol is defined in the third module. 5. The function and the derivative are defined as members of a struct extension in two separate modules - one for the function and one for the derivative. The changes allow to define derivatives for methods of `SIMD`. Fixes swiftlang#54445 <!-- If this pull request is targeting a release branch, please fill out the following form: https://github.com/swiftlang/.github/blob/main/PULL_REQUEST_TEMPLATE/release.md?plain=1 Otherwise, replace this comment with a description of your changes and rationale. Provide links to external references/discussions if appropriate. If this pull request resolves any GitHub issues, link them like so: Resolves <link to issue>, resolves <link to another issue>. For more information about linking a pull request to an issue, see: https://docs.github.com/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue --> <!-- Before merging this pull request, you must run the Swift continuous integration tests. For information about triggering CI builds via @swift-ci, see: https://github.com/apple/swift/blob/main/docs/ContinuousIntegration.md#swift-ci Thank you for your contribution to Swift! -->
1 parent 29a9fb0 commit 1e403ec

File tree

30 files changed

+476
-31
lines changed

30 files changed

+476
-31
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4387,6 +4387,9 @@ NOTE(derivative_attr_fix_access,none,
43874387
"mark the derivative function as "
43884388
"'%select{private|fileprivate|internal|package|@usableFromInline|@usableFromInline}0' "
43894389
"to match the original function", (AccessLevel))
4390+
ERROR(derivative_attr_always_emit_into_client_mismatch,none,
4391+
"either both or none of derivative and original function must have "
4392+
"@alwaysEmitIntoClient attribute", ())
43904393
ERROR(derivative_attr_static_method_mismatch_original,none,
43914394
"unexpected derivative function declaration; "
43924395
"%0 requires the derivative function %1 to be %select{an instance|a 'static'}2 method",

lib/SIL/IR/Linker.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,23 @@ void SILLinkerVisitor::maybeAddFunctionToWorklist(
159159
// HiddenExternal linkage when they are declarations, then they
160160
// become Shared after the body has been deserialized.
161161
// So try deserializing HiddenExternal functions too.
162-
if (linkage == SILLinkage::HiddenExternal)
163-
return deserializeAndPushToWorklist(F);
164-
162+
if (linkage == SILLinkage::HiddenExternal) {
163+
deserializeAndPushToWorklist(F);
164+
if (!F->markedAsAlwaysEmitIntoClient())
165+
return;
166+
// For @_alwaysEmitIntoClient functions, we need to lookup its
167+
// differentiability witness and, if present, ask SILLoader to obtain its
168+
// definition. Otherwise, a linker error would occur due to undefined
169+
// reference to these symbols.
170+
for (SILDifferentiabilityWitness *witness :
171+
F->getModule().lookUpDifferentiabilityWitnessesForFunction(
172+
F->getName())) {
173+
F->getModule().getSILLoader()->lookupDifferentiabilityWitness(
174+
witness->getKey());
175+
}
176+
return;
177+
}
178+
165179
// Update the linkage of the function in case it's different in the serialized
166180
// SIL than derived from the AST. This can be the case with cross-module-
167181
// optimizations.

lib/SILGen/SILGen.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1435,14 +1435,19 @@ void SILGenModule::emitDifferentiabilityWitness(
14351435
auto *diffWitness = M.lookUpDifferentiabilityWitness(key);
14361436
if (!diffWitness) {
14371437
// Differentiability witnesses have the same linkage as the original
1438-
// function, stripping external.
1439-
auto linkage = stripExternalFromLinkage(originalFunction->getLinkage());
1438+
// function, stripping external. For @_alwaysEmitIntoClient original
1439+
// functions, force PublicNonABI linkage of the differentiability witness so
1440+
// we can serialize it (the original function itself might be HiddenExternal
1441+
// in this case if we only have declaration without definition).
1442+
auto linkage =
1443+
originalFunction->markedAsAlwaysEmitIntoClient()
1444+
? SILLinkage::PublicNonABI
1445+
: stripExternalFromLinkage(originalFunction->getLinkage());
14401446
diffWitness = SILDifferentiabilityWitness::createDefinition(
14411447
M, linkage, originalFunction, diffKind, silConfig.parameterIndices,
14421448
silConfig.resultIndices, config.derivativeGenericSignature,
14431449
/*jvp*/ nullptr, /*vjp*/ nullptr,
1444-
/*isSerialized*/ hasPublicVisibility(originalFunction->getLinkage()),
1445-
attr);
1450+
/*isSerialized*/ hasPublicVisibility(linkage), attr);
14461451
}
14471452

14481453
// Set derivative function in differentiability witness.

lib/SILGen/SILGenPoly.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6498,8 +6498,14 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
64986498
auto loc = customDerivativeFn->getLocation();
64996499
SILGenFunctionBuilder fb(*this);
65006500
// Derivative thunks have the same linkage as the original function, stripping
6501-
// external.
6502-
auto linkage = stripExternalFromLinkage(originalFn->getLinkage());
6501+
// external. For @_alwaysEmitIntoClient original functions, force PublicNonABI
6502+
// linkage of derivative thunks so we can serialize them (the original
6503+
// function itself might be HiddenExternal in this case if we only have
6504+
// declaration without definition).
6505+
auto linkage = originalFn->markedAsAlwaysEmitIntoClient()
6506+
? SILLinkage::PublicNonABI
6507+
: stripExternalFromLinkage(originalFn->getLinkage());
6508+
65036509
auto *thunk = fb.getOrCreateFunction(
65046510
loc, name, linkage, thunkFnTy, IsBare, IsNotTransparent,
65056511
customDerivativeFn->getSerializedKind(),

lib/SILOptimizer/Differentiation/Common.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -538,9 +538,14 @@ SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness(
538538
"definitions with explicit differentiable attributes");
539539

540540
return SILDifferentiabilityWitness::createDeclaration(
541-
module, SILLinkage::PublicExternal, original, kind,
542-
minimalConfig->parameterIndices, minimalConfig->resultIndices,
543-
minimalConfig->derivativeGenericSignature);
541+
module,
542+
// Witness for @_alwaysEmitIntoClient original function must be emitted,
543+
// otherwise a linker error would occur due to undefined reference to the
544+
// witness symbol.
545+
original->markedAsAlwaysEmitIntoClient() ? SILLinkage::PublicNonABI
546+
: SILLinkage::PublicExternal,
547+
original, kind, minimalConfig->parameterIndices,
548+
minimalConfig->resultIndices, minimalConfig->derivativeGenericSignature);
544549
}
545550

546551
} // end namespace autodiff

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -999,10 +999,14 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness(
999999

10001000
// We can generate empty JVP / VJP for functions available externally. These
10011001
// functions have the same linkage as the original ones sans `external`
1002-
// flag. Important exception here hidden_external functions as they are
1003-
// serializable but corresponding hidden ones would be not and the SIL
1004-
// verifier will fail. Patch `serializeFunctions` for this case.
1005-
if (orig->getLinkage() == SILLinkage::HiddenExternal)
1002+
// flag. Important exception here hidden_external non-@_alwaysEmitIntoClient
1003+
// functions as they are serializable but corresponding hidden ones would be
1004+
// not and the SIL verifier will fail. Patch `serializeFunctions` for this
1005+
// case. For @_alwaysEmitIntoClient original functions (which might be
1006+
// HiddenExternal if we only have declaration without definition), we want
1007+
// derivatives to be serialized and do not patch `serializeFunctions`.
1008+
if (orig->getLinkage() == SILLinkage::HiddenExternal &&
1009+
!orig->markedAsAlwaysEmitIntoClient())
10061010
serializeFunctions = IsNotSerialized;
10071011

10081012
// If the JVP doesn't exist, need to synthesize it.

lib/Sema/TypeCheckAttr.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6990,6 +6990,13 @@ static bool typeCheckDerivativeAttr(DerivativeAttr *attr) {
69906990
return true;
69916991
}
69926992

6993+
if (originalAFD->getAttrs().hasAttribute<AlwaysEmitIntoClientAttr>() !=
6994+
derivative->getAttrs().hasAttribute<AlwaysEmitIntoClientAttr>()) {
6995+
diags.diagnose(derivative->getLoc(),
6996+
diag::derivative_attr_always_emit_into_client_mismatch);
6997+
return true;
6998+
}
6999+
69937000
// Get the resolved differentiability parameter indices.
69947001
auto *resolvedDiffParamIndices = attr->getParameterIndices();
69957002

stdlib/public/Differentiation/SIMDDifferentiation.swift.gyb

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -405,9 +405,6 @@ where
405405
}
406406
}
407407

408-
// FIXME(TF-1103): Derivative registration does not yet support
409-
// `@_alwaysEmitIntoClient` original functions like `SIMD.sum()`.
410-
/*
411408
extension SIMD
412409
where
413410
Self: Differentiable,
@@ -417,6 +414,7 @@ where
417414
TangentVector == Self
418415
{
419416
@inlinable
417+
@_alwaysEmitIntoClient
420418
@derivative(of: sum)
421419
func _vjpSum() -> (
422420
value: Scalar, pullback: (Scalar.TangentVector) -> TangentVector
@@ -425,14 +423,14 @@ where
425423
}
426424

427425
@inlinable
426+
@_alwaysEmitIntoClient
428427
@derivative(of: sum)
429428
func _jvpSum() -> (
430429
value: Scalar, differential: (TangentVector) -> Scalar.TangentVector
431430
) {
432431
return (sum(), { v in Scalar.TangentVector(v.sum()) })
433432
}
434433
}
435-
*/
436434

437435
extension SIMD
438436
where

test/AutoDiff/SILGen/nil_coalescing.swift

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
// RUN: %target-swift-frontend -Xllvm -sil-print-types -emit-sil -verify %s | %FileCheck %s
1+
/// Note: -primary-file prevents non_abi->shared linkage change in `removeSerializedFlagFromAllFunctions`
2+
// RUN: %target-swift-frontend -Xllvm -sil-print-types -emit-sil -verify -primary-file %s | %FileCheck %s
23

34
import _Differentiation
45

5-
// CHECK: sil @test_nil_coalescing
6+
// CHECK: sil non_abi @test_nil_coalescing
67
// CHECK: bb0(%{{.*}} : $*T, %[[ARG_OPT:.*]] : $*Optional<T>, %[[ARG_PB:.*]] :
78
// CHECK: $@noescape @callee_guaranteed @substituted <τ_0_0> () -> (@out τ_0_0, @error any Error) for <T>):
89
// CHECK: %[[ALLOC_OPT:.*]] = alloc_stack [lexical] $Optional<T>
@@ -15,7 +16,7 @@ import _Differentiation
1516
//
1617
@_silgen_name("test_nil_coalescing")
1718
@derivative(of: ??)
18-
@usableFromInline
19+
@_alwaysEmitIntoClient
1920
func nilCoalescing<T: Differentiable>(optional: T?, defaultValue: @autoclosure () throws -> T)
2021
rethrows -> (value: T, pullback: (T.TangentVector) -> Optional<T>.TangentVector)
2122
{

test/AutoDiff/Sema/derivative_attr_type_checking.swift

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,6 +1062,15 @@ func _internal_original_inlinable_derivative(_ x: Float) -> (value: Float, pullb
10621062
fatalError()
10631063
}
10641064

1065+
func internal_original_alwaysemitintoclient_derivative_error(_ x: Float) -> Float { x }
1066+
@_alwaysEmitIntoClient
1067+
@derivative(of: internal_original_alwaysemitintoclient_derivative_error)
1068+
// expected-error @+1 {{either both or none of derivative and original function must have @alwaysEmitIntoClient attribute}}
1069+
func _internal_original_alwaysemitintoclient_derivative_error(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
1070+
fatalError()
1071+
}
1072+
1073+
@_alwaysEmitIntoClient
10651074
func internal_original_alwaysemitintoclient_derivative(_ x: Float) -> Float { x }
10661075
@_alwaysEmitIntoClient
10671076
@derivative(of: internal_original_alwaysemitintoclient_derivative)
@@ -1084,6 +1093,15 @@ package func _package_original_inlinable_derivative(_ x: Float) -> (value: Float
10841093
fatalError()
10851094
}
10861095

1096+
@_alwaysEmitIntoClient
1097+
package func package_original_alwaysemitintoclient_derivative_error(_ x: Float) -> Float { x }
1098+
@derivative(of: package_original_alwaysemitintoclient_derivative_error)
1099+
// expected-error @+1 {{either both or none of derivative and original function must have @alwaysEmitIntoClient attribute}}
1100+
package func _package_original_alwaysemitintoclient_derivative_error(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
1101+
fatalError()
1102+
}
1103+
1104+
@_alwaysEmitIntoClient
10871105
package func package_original_alwaysemitintoclient_derivative(_ x: Float) -> Float { x }
10881106
@_alwaysEmitIntoClient
10891107
@derivative(of: package_original_alwaysemitintoclient_derivative)

0 commit comments

Comments
 (0)