Enable the use of structured gradients for Zygote
#62
- Internally always use `Optimisers.destructure`
- Use structured gradients for Zygote
- Don't use `DiffResults` and just pass around gradients
I'm not so certain this is true? There are many "large-scale" cases where ReverseDiff.jl beats Zygote.jl easily. Similarly, we have the "up-and-coming" ADs like Enzyme.jl and Tapir.jl, which both promise further perf gains beyond Zygote.jl. IMO it seems a bit premature to go down the Zygote-friendly route 😕
maybe_destructure(::ADTypes.AutoZygote, q) = (q, identity)

maybe_destructure(::ADTypes.AbstractADType, q) = Optimisers.destructure(q)
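For readers following along, here is a minimal sketch of what the two methods above return. The NamedTuple `q` is a hypothetical stand-in for a variational family and is not taken from the PR; `AutoForwardDiff` is used only as an example of a non-Zygote backend.

```julia
using ADTypes, Optimisers

# The two methods from the diff, repeated so the sketch is self-contained.
maybe_destructure(::ADTypes.AutoZygote, q) = (q, identity)
maybe_destructure(::ADTypes.AbstractADType, q) = Optimisers.destructure(q)

# Hypothetical parameters (an assumption for illustration only).
q = (location = [0.0, 0.0], scale = [1.0, 2.0])

# Zygote path: parameters stay structured and the "restructure" is a no-op.
params_zygote, re_zygote = maybe_destructure(ADTypes.AutoZygote(), q)

# Any other backend: parameters are flattened into a single vector,
# and `re_flat` rebuilds the original structure from that vector.
params_flat, re_flat = maybe_destructure(ADTypes.AutoForwardDiff(), q)
```

The choice between the two paths is made purely by the AD type, which is why it cannot by itself tell apart different parameterizations of the same distribution.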
The annoyance with using dispatch to do the destructuring is that you now need to define new structs for every type of parameterization of a distribution you want to do.
As in, how do you separate between, say, an `MvNormal` with a diagonal and a dense covariance matrix here?
The intent here was to decide whether to use `destructure` at all, depending on the ADType.
@torfjelde I agree. Let's not do this.
This PR restructures the project so that `Zygote` can use structured gradients without having to flatten everything. Also, the interface forces the use of `Optimisers.destructure`, unlike the previous version where we exposed some control over this. Overall, the changes are summarized as follows:

- `Zygote` can now use structured gradients directly, without having to flatten the parameters.
- The `optimize` interface is simpler, as we do not expose control over `restructure` (we will always use `Optimisers.destructure` internally).
- We do not use `DiffResults` anymore and just pass around gradients directly. This makes most of the package immutable, which is nice, but could impact memory usage/GC time. However, in large-scale problems, `Zygote` will be the only option, so I think this is okay in the sense of prioritizing `Zygote`-friendliness.

Any concerns/comments would be much appreciated!
Also, sorry that the diff is overlapping with #61!
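
To make the distinction concrete, here is a rough sketch contrasting a structured gradient with a flattened one. The NamedTuple `q` and the `loss` function are hypothetical placeholders, not code from this PR; only `Zygote.gradient` and `Optimisers.destructure` are real library calls.

```julia
using Optimisers, Zygote

# Hypothetical parameters and objective (assumptions for illustration only).
q = (location = [1.0, 2.0], scale = [3.0])
loss(q) = sum(abs2, q.location) + sum(abs2, q.scale)

# Structured gradient: Zygote returns a NamedTuple mirroring the fields of `q`.
grad_structured = only(Zygote.gradient(loss, q))

# Flattened gradient: differentiate with respect to the flat vector and
# rebuild the structure inside the objective via `re`.
flat, re = Optimisers.destructure(q)
grad_flat = only(Zygote.gradient(x -> loss(re(x)), flat))
```

Both routes carry the same numbers; the structured one simply skips the flatten/rebuild round trip.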