Skip to content

Commit bcae501

Browse files
authored
[AutoDiff] NFC: garden test. (swiftlang#32209)
Clarify test name and contents.
1 parent d3b6b89 commit bcae501

File tree

1 file changed

+15
-15
lines changed

1 file changed

+15
-15
lines changed

test/AutoDiff/validation-test/simple_math.swift

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,21 @@ SimpleMathTests.test("ForceUnwrapping") {
368368
expectEqual((1, 2), forceUnwrap(Float(2)))
369369
}
370370

371-
// CHECK-LABEL: sil private [ossa] @AD__${{.*}}jumpTimesTwo{{.*}}pullback_src_0_wrt_0 : $@convention(thin) (Float, @owned _AD__$s4nullyycfU18_12jumpTimesTwoL_5modelSfAAyycfU18_14SmallTestModelL_V_tF_bb0__PB__src_0_wrt_0) -> SmallTestModel.TangentVector {
371+
SimpleMathTests.test("Adjoint value accumulation for aggregate lhs and concrete rhs") {
372+
// TF-943: Test adjoint value accumulation for aggregate lhs and concrete rhs.
373+
struct SmallTestModel : Differentiable {
374+
public var stored: Float = 3.0
375+
@differentiable public func callAsFunction() -> Float { return stored }
376+
}
377+
378+
func doubled(_ model: SmallTestModel) -> Float{
379+
return model() + model.stored
380+
}
381+
let grads = gradient(at: SmallTestModel(), in: doubled)
382+
expectEqual(2.0, grads.stored)
383+
}
384+
385+
// CHECK-LABEL: sil private [ossa] @AD__${{.*}}doubled{{.*}}pullback_src_0_wrt_0 : $@convention(thin) (Float, @owned {{.*}}) -> SmallTestModel.TangentVector {
372386
// CHECK: bb0([[DX:%.*]] : $Float, [[PB_STRUCT:%.*]] : {{.*}}):
373387
// CHECK: ([[PB0:%.*]], [[PB1:%.*]]) = destructure_struct [[PB_STRUCT]]
374388
// CHECK: [[ADJ_TUPLE:%.*]] = apply [[PB1]]([[DX]]) : $@callee_guaranteed (Float) -> (Float, Float)
@@ -387,18 +401,4 @@ SimpleMathTests.test("ForceUnwrapping") {
387401
// CHECK: return [[RES_STRUCT]] : $SmallTestModel.TangentVector
388402
// CHECK: }
389403

390-
SimpleMathTests.test("Struct") {
391-
// TF-943: Test adjoint value accumulation for aggregate lhs and concrete rhs.
392-
struct SmallTestModel : Differentiable {
393-
public var jump: Float = 3.0
394-
@differentiable public func callAsFunction() -> Float { return jump }
395-
}
396-
397-
func jumpTimesTwo(model: SmallTestModel) -> Float{
398-
return model() + model.jump
399-
}
400-
let grads = gradient(at: SmallTestModel(), in: jumpTimesTwo)
401-
expectEqual(2.0, grads.jump)
402-
}
403-
404404
runAllTests()

0 commit comments

Comments
 (0)