@@ -368,7 +368,21 @@ SimpleMathTests.test("ForceUnwrapping") {
368
368
expectEqual ( ( 1 , 2 ) , forceUnwrap ( Float ( 2 ) ) )
369
369
}
370
370
371
- // CHECK-LABEL: sil private [ossa] @AD__${{.*}}jumpTimesTwo{{.*}}pullback_src_0_wrt_0 : $@convention(thin) (Float, @owned _AD__$s4nullyycfU18_12jumpTimesTwoL_5modelSfAAyycfU18_14SmallTestModelL_V_tF_bb0__PB__src_0_wrt_0) -> SmallTestModel.TangentVector {
371
+ SimpleMathTests . test ( " Adjoint value accumulation for aggregate lhs and concrete rhs " ) {
372
+ // TF-943: Test adjoint value accumulation for aggregate lhs and concrete rhs.
373
+ struct SmallTestModel : Differentiable {
374
+ public var stored : Float = 3.0
375
+ @differentiable public func callAsFunction( ) -> Float { return stored }
376
+ }
377
+
378
+ func doubled( _ model: SmallTestModel ) -> Float {
379
+ return model ( ) + model. stored
380
+ }
381
+ let grads = gradient ( at: SmallTestModel ( ) , in: doubled)
382
+ expectEqual ( 2.0 , grads. stored)
383
+ }
384
+
385
+ // CHECK-LABEL: sil private [ossa] @AD__${{.*}}doubled{{.*}}pullback_src_0_wrt_0 : $@convention(thin) (Float, @owned {{.*}}) -> SmallTestModel.TangentVector {
372
386
// CHECK: bb0([[DX:%.*]] : $Float, [[PB_STRUCT:%.*]] : {{.*}}):
373
387
// CHECK: ([[PB0:%.*]], [[PB1:%.*]]) = destructure_struct [[PB_STRUCT]]
374
388
// CHECK: [[ADJ_TUPLE:%.*]] = apply [[PB1]]([[DX]]) : $@callee_guaranteed (Float) -> (Float, Float)
@@ -387,18 +401,4 @@ SimpleMathTests.test("ForceUnwrapping") {
387
401
// CHECK: return [[RES_STRUCT]] : $SmallTestModel.TangentVector
388
402
// CHECK: }
389
403
390
- SimpleMathTests . test ( " Struct " ) {
391
- // TF-943: Test adjoint value accumulation for aggregate lhs and concrete rhs.
392
- struct SmallTestModel : Differentiable {
393
- public var jump : Float = 3.0
394
- @differentiable public func callAsFunction( ) -> Float { return jump }
395
- }
396
-
397
- func jumpTimesTwo( model: SmallTestModel ) -> Float {
398
- return model ( ) + model. jump
399
- }
400
- let grads = gradient ( at: SmallTestModel ( ) , in: jumpTimesTwo)
401
- expectEqual ( 2.0 , grads. jump)
402
- }
403
-
404
404
runAllTests ( )
0 commit comments