@@ -150,8 +150,11 @@ function ChainRulesCore.rrule(::Type{<:ColVecs}, X::AbstractMatrix)
150
150
function ColVecs_pullback (:: AbstractVector{<:AbstractVector{<:Real}} )
151
151
return error (
152
152
" Pullback on AbstractVector{<:AbstractVector}.\n " *
153
- " This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`.\n " *
154
- " To solve this issue overload `kernelmatrix(_diag)` for your kernel for `ColVecs`" ,
153
+ " This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`,\n " *
154
+ " or because some external computation has acted on `ColVecs` to produce a vector of vectors." *
155
+ " In the former case, to solve this issue overload `kernelmatrix(_diag)` for your kernel for `ColVecs`." *
156
+ " In the latter case, one needs to track down the `rrule` whose pullback returns a `Vector{Vector{T}}`," *
157
+ " rather than a `Tangent`, as the cotangent / gradient for `ColVecs` input, and circumvent it."
155
158
)
156
159
end
157
160
return ColVecs (X), ColVecs_pullback
@@ -162,8 +165,9 @@ function ChainRulesCore.rrule(::Type{<:RowVecs}, X::AbstractMatrix)
162
165
function RowVecs_pullback (:: AbstractVector{<:AbstractVector{<:Real}} )
163
166
return error (
164
167
" Pullback on AbstractVector{<:AbstractVector}.\n " *
165
- " This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`.\n " *
166
- " To solve this issue overload `kernelmatrix(_diag)` for your kernel for `RowVecs`" ,
168
+ " This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`,\n " *
169
+ " or because some external computation has acted on `RowVecs` to produce a vector of vectors." *
170
+ " If it is the former, to solve this issue overload `kernelmatrix(_diag)` for your kernel for `RowVecs`" ,
167
171
)
168
172
end
169
173
return RowVecs (X), RowVecs_pullback
0 commit comments