Add "low-rank" variational families #76
Okay, the tests didn't pass because the latest Enzyme patch somehow broke some stuff, but all tests passed locally.
Interesting! I'll take a look later (probably over the weekend). Sorry for the possible delay.
Hi @Red-Portal! I was actually just yesterday thinking about something like `MvLocationScale`, I hadn't realised it existed already. This low-rank version seems cool too.
I put a bunch of questions and proposals I had in local comments. It's all about code stuff, I don't have much expertise on the theory side here.
More broadly, and probably not to be addressed in this PR, but is there a reason to keep `MvLocationScale` and `MvLocationScaleLowRank` in AdvancedVI, rather than somewhere more centrally in TuringLang so that one could use them more broadly with Turing.jl, or maybe even in Distributions.jl?
I suspect the main reason @yebai tagged me as a reviewer though is the Enzyme failure. I'll look into it.
The Enzyme issue is this: The function
Since
which is just explicitly telling Enzyme that you, the caller, guarantee that
All make sense.
Is #75 solved by this? I saw that this PR refers to it, but it isn't directly mentioned anywhere.
Looks good to me apart from the bug where LocationScaleLowRank assumes zero mean and unit variance. I know you said you'd fix that in a separate PR, but I'm a bit wary of committing to master code that has a known, significant bug in it. Could we make it so that the code errors out if the assumption of normalisation is violated? I can just imagine a situation where this gets into master, someone tries to use it before the other PR gets done, and it silently gives wrong results.
[:meanfield, :fullrank],
realtype in [Float32, Float64],
bijector in [nothing, :identity]
This is probably the formatter's work, but the indents here are quite odd.
Thanks @Red-Portal, and sorry for being a bit slow to respond. I had a few more questions. I'll also need to build and read the new docs still, haven't done that yet.
σ2 = var(q.dist)
return σ2 * Hermitian(C * C')
I don't know the theory here well, but is there a reason why this involves `var(q.dist)` rather than `cov(q.dist)`? I could have imagined it being something like `C * cov(q.dist) * C'`, though that's just a not-very-educated guess.
Good point. I was thinking that `q.dist` was constrained to be a univariate distribution, which would make all of this valid, but seems like I have to use `ContinuousUnivariateDistribution` for that. Let me fix this later.
Oh I see, yeah, this makes sense for univariate. Is there a reason you want to restrict to `q.dist` being univariate? Just less of a headache to implement?
Yeah, I thought it would be the easiest way to force people to provide a standardized isotropic distribution. We're not quite forcing it to be standardized, but at least this guarantees it is isotropic.
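For reference, the identity behind the `var(q.dist)` form discussed in this thread can be written out as follows. This is a sketch under the assumption (taken from the thread, not checked against the code) that the base sample $z$ has i.i.d. coordinates drawn from the univariate `q.dist` with variance $\sigma^2$; the symbols $m$ (location) and $z$ (base sample) are introduced here only for illustration.

```latex
\begin{aligned}
x &= m + C z, \qquad \operatorname{Cov}(z) = \sigma^2 I, \\
\operatorname{Cov}(x) &= C \, \operatorname{Cov}(z) \, C^\top = \sigma^2 \, C C^\top .
\end{aligned}
```

With a general multivariate base distribution the middle factor would not collapse to a scalar, which is the `C * cov(q.dist) * C'` form suggested above.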
Docs look good. I still wonder if these distributions would be useful more broadly than just within AdvancedVI.
@yebai The original formatter complaints were manual touches suggested by @mhauru because the formatter did a pretty ugly job in a few places. But I guess it will be hard to do manual formatting in the long run since people will just want to run the formatter without having to manually revert to the non-standard styles.
A test seems to fail because of float rounding, I proposed a few two-char changes to hopefully fix that. I'm not sure if all of them are necessary, reject the ones that don't make sense to you.
If you want to add the constraint that the base distribution needs to be univariate, then after that I'm out of nits to pick and happy to approve.
EDIT: Oh and on the formatter thing, yeah, if the formatter makes it ugly I'm happy to still go with what the formatter does, for consistency and ease.
Co-authored-by: Markus Hauru <[email protected]>
Thanks @Red-Portal! Some Enzyme test now seems to fail, but I'll approve since the code itself looks good. Not immediately obvious to me what the issue is, feel free to look into it, or I can try to dig into it, hopefully sometime in the next couple of days.
Thanks @Red-Portal, looks great! I have no idea what the right value for `scale_eps` is, happy to take your word for it.
This PR adds low-rank variational families, which cannot be simply represented as a location-scale family (the reparameterization path has to be modified).
MvLocationLowRankScale
rand
logpdf
entropy
mean, var, and cov
The tricky part would be `logpdf` and `entropy` since, to be done efficiently, they will have to involve low-rank Cholesky updates. Given that low-rank Cholesky updates are niche, I am not sure whether their AD is up to the task.
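To make the cost concern concrete, here is a minimal sketch of the two quantities that `logpdf` and `entropy` need, the log-determinant and the inverse quadratic form. Note that it does not use low-rank Cholesky updates: it swaps in the matrix determinant lemma and the Woodbury identity, which only need a Cholesky factorization of a small r×r matrix. It assumes a covariance of the form `Diagonal(d.^2) + U*U'`; the names `d`, `U`, `lowrank_logdet`, and `lowrank_invquad` are hypothetical and not taken from this PR.

```julia
using LinearAlgebra

# Assumed (hypothetical) parameterization: Σ = Diagonal(d.^2) + U*U',
# with d a length-n vector and U an n×r matrix, r much smaller than n.

# r×r "capacitance" matrix K = I + U' D^{-2} U, factorized once.
function capacitance(d::AbstractVector, U::AbstractMatrix)
    Dinv2U = U ./ (d .^ 2)               # D^{-2} U: scales row i by 1/d[i]^2
    return cholesky(Symmetric(I + U' * Dinv2U))
end

# Matrix determinant lemma: logdet(Σ) = logdet(D^2) + logdet(K).
function lowrank_logdet(d::AbstractVector, U::AbstractMatrix)
    return 2 * sum(log, abs.(d)) + logdet(capacitance(d, U))
end

# Woodbury identity: x' Σ^{-1} x = x' D^{-2} x - w' K^{-1} w, with w = U' D^{-2} x.
function lowrank_invquad(d::AbstractVector, U::AbstractMatrix, x::AbstractVector)
    Dinv2x = x ./ (d .^ 2)
    w = U' * Dinv2x
    return dot(x, Dinv2x) - dot(w, capacitance(d, U) \ w)
end

# Small usage example, compared against the dense computation.
d = [1.0, 2.0, 0.5, 1.5]
U = [1.0 0.0; 0.5 1.0; 0.0 2.0; 1.0 1.0]
x = [0.3, -1.2, 0.7, 2.0]
Σ = Diagonal(d .^ 2) + U * U'
println(lowrank_logdet(d, U) ≈ logdet(Σ))
println(lowrank_invquad(d, U, x) ≈ dot(x, Σ \ x))
```

Both functions cost O(n r²) rather than O(n³), and only use dense operations plus a small Cholesky factorization. Whether a given AD backend differentiates these efficiently would still need to be checked.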