Explain the Abstract Primals Problem #343

Open
@oxinabox

There is a problem that comes up with abstract primals, most commonly in the case of AbstractArrays.
We don't have a good explanation of what the problem is anywhere; it is scattered across various issues and PRs.
@willtebbutt has spent a bunch of time thinking about it.

I propose that we open a docs PR that clearly explains it, with examples etc., as part of the design docs section.
Once we have that PR open, we can talk more about solving it.

@mcabbott and I were discussing this on Slack. (They might post their notes on this later.)
The below is roughly extracted from that. It discusses a lot of the problem, though it isn't super clear,
since the whole discussion exists because we don't have a clear explanation of what the problem is.


Rough ugly notes:

The problem is that we want to define rules not just for Arrays but also for StaticArrays.
Their nearest common supertype is AbstractArray.
But if you define rules on AbstractArray, people say this is bad because it means Diagonal will take O(N^2) rather than O(N).
One could say that the user opted into this, since they used an AbstractMatrix, and so it is on the method author (in this case the rule author) to provide an optimized method if appropriate.
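
As a rough illustration of the complexity concern, here is a sketch using plain functions. `generic_sum_pullback` and `structural_sum_pullback` are hypothetical names, not ChainRules API. A pullback written once for every AbstractMatrix materialises all N^2 entries, while one that respects the structure of Diagonal only stores N:

```julia
using LinearAlgebra

# Hypothetical pullback for `sum`, written once for every AbstractMatrix.
# It allocates a dense N×N result regardless of the input's structure.
generic_sum_pullback(A::AbstractMatrix, ybar) = fill(ybar, size(A))

# Hypothetical structure-aware version: only the N diagonal entries are stored.
structural_sum_pullback(D::Diagonal, ybar) = Diagonal(fill(ybar, length(D.diag)))

D = Diagonal([1.0, 2.0, 3.0])
dense = generic_sum_pullback(D, 1.0)          # a dense Matrix with N^2 entries
structured = structural_sum_pullback(D, 1.0)  # a Diagonal with N stored entries
```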

But there is a greater problem.
IIRC some operations on FillArrays give the wrong answer, not just the wrong time complexity, if you treat them as generic Arrays.
Mike and Will T had a big argument about it.
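
The specific failing operations are not recorded in these notes. As one hypothetical illustration of how the array view and the struct view of such a type can diverge, here is a sketch with a minimal Fill-like type. `MyFill` is made up for this example and is not the FillArrays implementation:

```julia
# Hypothetical minimal Fill-like array: one stored value presented as a length-n vector.
struct MyFill{T} <: AbstractVector{T}
    value::T
    n::Int
end
Base.size(f::MyFill) = (f.n,)
Base.IndexStyle(::Type{<:MyFill}) = IndexLinear()
Base.getindex(f::MyFill, i::Int) = f.value

f = MyFill(2.0, 4)

# Array view: the gradient of `sum` has n independent entries,
# which a MyFill can only represent if they all happen to be equal.
array_style = ones(length(f))

# Struct view: sum(f) == f.n * f.value, so the only differentiable field is
# `value`, and its derivative is n. The integer field `n` gets no tangent.
struct_style = (value = float(f.n), n = nothing)
```

The two answers describe the same derivative only after the array-style entries are summed into the single `value` field, which is the step a generic array rule does not know to do.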

However:
If you only define rules on fundamental array types like Array, and maybe StaticArray and GPUArrays, you get the correct time complexity, and you avoid defining things so generally that they break weirder array types.
And the AD will decompose the wrapper arrays correctly and get at what is inside.
FillArrays work correctly if you treat them like structs. (As do all wrapper arrays.)
@willtebbutt has spent ages thinking about this to understand the problem, though I think we still don't have a great solution.
I am hoping Will will write something that we can put in the docs. (We have a few nice writeups like that on there.)
I think we might actually need to formally introduce the idea of a fundamental array type into ChainRules, maybe as a trait.
Or maybe just add a StaticArrays dep (it's basically a stdlib at this point) and then we can have a union for them.
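
A minimal sketch of what the trait option could look like. Both `is_fundamental_array` and `should_use_array_rule` are hypothetical names that do not exist in ChainRules:

```julia
using LinearAlgebra

# Hypothetical trait: marks array types that store their elements directly
# rather than wrapping another array. Defaults to false.
is_fundamental_array(::Type{<:AbstractArray}) = false
is_fundamental_array(::Type{<:Array}) = true
# A static-array or GPU-array package could opt in by adding one more method.

# Hypothetical helper: an array rule would only be applied when this holds;
# otherwise the AD is left to decompose the wrapper itself.
should_use_array_rule(x::AbstractArray) = is_fundamental_array(typeof(x))

should_use_array_rule(rand(2, 2))             # true: plain Array
should_use_array_rule(Diagonal([1.0, 2.0]))   # false: wrapper around a Vector
```

Compared with a hard-coded union, a trait would let other packages opt in without ChainRules depending on them.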

Now, defining things only on concrete types seems unidiomatic.
Julia code is normally defined for the general case, and then multiple dispatch is used to provide optimizations for specific cases.
However, AD has two ways to achieve functionality: rules and code generation. And we want to ensure that, overall, the most specific one gets hit.
In Julia, the most specific method takes precedence over less specific ones, right?
The most specific method is the one that is most customised for that type.
With AD the most specific functionality is always available: it is to let the AD run, which generates code for the exact input.
So when a rule is applied it is actually getting in the way of the most specific functionality.
E.g. when a rule is defined for AbstractMatrix, that prevents the AD from generating the more specific functionality for Symmetric{Diagonal{...}}.
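
A toy sketch of this shadowing effect. Every name here is hypothetical and the "AD" only returns description strings, so this shows the dispatch behaviour rather than any real AD system:

```julia
using LinearAlgebra

# A "rule" returns a description string, or `nothing` when no rule applies.
broad_rule(x) = nothing
broad_rule(x::AbstractMatrix) = "rule for AbstractMatrix"

narrow_rule(x) = nothing
narrow_rule(x::Matrix) = "rule for Matrix"

# The toy AD prefers a rule when one exists; otherwise it "generates" code
# specialised to the exact input type.
function toy_ad(rule, x)
    r = rule(x)
    return r === nothing ? "generated for $(nameof(typeof(x)))" : r
end

D = Diagonal([1.0, 2.0])
toy_ad(broad_rule, D)   # the broad rule intercepts the Diagonal
toy_ad(narrow_rule, D)  # no rule applies, so the toy AD specialises on Diagonal
```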

When I say wrapper arrays I mean anything that has a parent array (according to the parent function).
Though it really is more general: it is anything which has the method under consideration defined without resorting to ccall.
As long as it doesn't resort to ccall, the AD will be able to generate a pullback for the method.
That generated pullback will call something that we have an optimised rule for. It might not even be the same function, but it will be something, and so we are solid.
Remember that AD systems generally generate code that is optimal if they don't error, except where the rule author is applying specific domain expertise.
So we only need the rule to catch the erroring case.

Julia's specialisation rules do apply to AD rules and to code generated by AD. But the AD doesn't get to generate its code if it hits a rule.
And the generated code will be more specialised than something hitting AbstractMatrix. This is the case where the AD on Diagonal would get to break things down according to the primal method definition and would end up hitting a rule for Vector, rather than one for Matrix.
Whereas the compiler's specialisation of an rrule for AbstractMatrix is not so specialised and ends up looking at a bunch of zero elements.
The AD would have done the better thing if the rule hadn't been defined.
Because AD systems are good at finding derivatives.
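
A hand-written sketch of the decomposition described above, using plain functions with hypothetical names rather than real ChainRules rules. It assumes the primal method for Diagonal forwards to the stored vector:

```julia
using LinearAlgebra

# Hypothetical primal: the Diagonal method forwards to the stored vector.
my_sum(v::Vector) = sum(v)
my_sum(D::Diagonal) = my_sum(D.diag)

# Hand-written rule for the fundamental type only.
# Returns the primal value and a pullback.
my_sum_rule(v::Vector) = (sum(v), ybar -> fill(ybar, length(v)))

# Roughly what an AD could generate for the Diagonal method: follow the primal
# definition, call the Vector rule, and wrap the result as a tangent for the
# `diag` field.
function generated_my_sum_rule(D::Diagonal)
    y, vector_pullback = my_sum_rule(D.diag)
    return y, ybar -> (diag = vector_pullback(ybar),)
end

D = Diagonal([1.0, 2.0, 3.0])
y, pullback = generated_my_sum_rule(D)
pullback(1.0)  # NamedTuple with a length-3 `diag` tangent; no N×N matrix is formed
```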

"Optimal" here means identical to the code someone would write by hand for this concrete input.
AD doesn't always achieve it, because sometimes there is domain knowledge to apply. But in simple cases, like decomposing functions on wrapper arrays, it does.

As I said: if you have domain knowledge then you can do better.
But generally that domain knowledge can still be applied to a "fundamental" array type (like Array, and maybe StaticArray and GPUArray) and it will still end up benefitting the wrapper array types.
And the generated code for the pullback between the wrapper type and the type that has the domain-knowledge rule will be optimal (in the sense of being basically identical to what a human would write by hand).

Labels: documentation (Improvements or additions to documentation)