Skip to content

Commit cf937ce

Browse files
authored
Make pullback error for ColVecs and RowVecs a bit more informative (#523)
* make adjoint error message a bit more informative * Apply suggestions from code review * Update Project.toml * Update src/chainrules.jl
1 parent 5127a26 commit cf937ce

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.10.56"
3+
version = "0.10.57"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/chainrules.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,11 @@ function ChainRulesCore.rrule(::Type{<:ColVecs}, X::AbstractMatrix)
150150
function ColVecs_pullback(::AbstractVector{<:AbstractVector{<:Real}})
151151
return error(
152152
"Pullback on AbstractVector{<:AbstractVector}.\n" *
153-
"This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`.\n" *
154-
"To solve this issue overload `kernelmatrix(_diag)` for your kernel for `ColVecs`",
153+
"This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`,\n" *
154+
"or because some external computation has acted on `ColVecs` to produce a vector of vectors." *
155+
"In the former case, to solve this issue overload `kernelmatrix(_diag)` for your kernel for `ColVecs`." *
156+
"In the latter case, one needs to track down the `rrule` whose pullback returns a `Vector{Vector{T}}`," *
157+
" rather than a `Tangent`, as the cotangent / gradient for `ColVecs` input, and circumvent it."
155158
)
156159
end
157160
return ColVecs(X), ColVecs_pullback
@@ -162,8 +165,9 @@ function ChainRulesCore.rrule(::Type{<:RowVecs}, X::AbstractMatrix)
162165
function RowVecs_pullback(::AbstractVector{<:AbstractVector{<:Real}})
163166
return error(
164167
"Pullback on AbstractVector{<:AbstractVector}.\n" *
165-
"This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`.\n" *
166-
"To solve this issue overload `kernelmatrix(_diag)` for your kernel for `RowVecs`",
168+
"This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`,\n" *
169+
"or because some external computation has acted on `RowVecs` to produce a vector of vectors." *
170+
"If it is the former, to solve this issue overload `kernelmatrix(_diag)` for your kernel for `RowVecs`",
167171
)
168172
end
169173
return RowVecs(X), RowVecs_pullback

0 commit comments

Comments
 (0)