
Add support for distributions with monotonically increasing bijector #297

Merged: 51 commits into master from torfjelde/ordered-for-monotonic, Jun 27, 2024

Conversation

@torfjelde (Member) commented Dec 3, 2023

Related: #220 and #295

Review threads on src/bijectors/ordered.jl (outdated, resolved) and src/interface.jl (resolved).
torfjelde and others added 2 commits December 3, 2023 16:59
@torfjelde (Member, Author)

@sethaxen maybe you want to have a look at this

@yebai requested a review from sunxd3 on December 4, 2023.
@sunxd3 (Member) commented Dec 4, 2023

Looks good to me, but maybe we want to wait for @sethaxen

@sethaxen (Member) left a comment:

The approach looks right! I suggest we add overloads for more bijectors. Is there a reason this PR doesn't also add is_monotonically_decreasing?

All univariate bijections are strictly monotonic. So we could define `is_monotonically_decreasing(b) = !is_monotonically_increasing(b)` if we documented that this function is only expected to give the correct answer when a univariate bijector is passed. But this could cause problems. Do we have any way to statically detect if a bijector is univariate?

The PR needs tests to cover each of the additions.
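A hedged sketch of what such tests might look like (illustrative only, not the PR's actual test code; it assumes the `exp`/`log` overloads discussed below are in place):

```julia
using Test
using Bijectors: Bijectors

@testset "is_monotonically_increasing" begin
    @test Bijectors.is_monotonically_increasing(exp)
    @test Bijectors.is_monotonically_increasing(log ∘ exp)  # increasing ∘ increasing
end
```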

Review thread on src/interface.jl (outdated):

```julia
function is_monotonically_increasing(cf::ComposedFunction)
    return is_monotonically_increasing(cf.inner) && is_monotonically_increasing(cf.outer)
end
is_monotonically_increasing(::typeof(exp)) = true
```
A reviewer (Member) suggested:

```diff
 is_monotonically_increasing(::typeof(exp)) = true
+is_monotonically_increasing(::typeof(log)) = true
+is_monotonically_increasing(binv::Inverse) = is_monotonically_increasing(inverse(binv))
```
A reviewer (Member) commented:

Since this is an interface function, would it be better to place these methods in the files where the corresponding bijectors are implemented? Also, I think we can mark this as true for Logit, LeakyReLU, Scale (when the scale is positive), Shift, and TruncatedBijector.
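For illustration, such per-file overloads might look like the following sketch (the `Scale` field name `a` is an assumption, not confirmed by this thread):

```julia
# In src/bijectors/shift.jl, src/bijectors/logit.jl, src/bijectors/scale.jl, etc.:
is_monotonically_increasing(::Shift) = true
is_monotonically_increasing(::Logit) = true
# Scale is only increasing when the scale coefficient(s) are positive,
# which requires a runtime check:
is_monotonically_increasing(b::Scale) = all(>(0), b.a)
```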

@torfjelde (Member, Author) replied:

Agreed:)

@torfjelde (Member, Author) commented Dec 10, 2023

> Is there a reason this PR doesn't also add is_monotonically_decreasing?

I did consider this, but AFAIK the only monotonically decreasing bijector we have right now is Scale with negative coefficients, which would require runtime checks and thus made me hesitant (I was going to raise an issue about this).

But it's probably worth it, so I'll add that too 👍

> Do we have any way to statically detect if a bijector is univariate?

Not at the moment, no.

> All univariate bijections are strictly monotonic. So we could define

And because we don't, I'd prefer to make it all explicit so we end up with a method error / always return false instead of silently doing something strange.
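A minimal sketch of that explicit approach, with an always-`false` fallback (illustrative; not the PR's exact code):

```julia
# Conservative fallbacks: anything without an explicit overload is
# reported as neither increasing nor decreasing.
is_monotonically_increasing(b) = false
is_monotonically_decreasing(b) = false

# Explicit opt-ins for known monotone functions.
is_monotonically_increasing(::typeof(exp)) = true
is_monotonically_increasing(::typeof(log)) = true
```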

@torfjelde (Member, Author) commented Dec 10, 2023

Ah okay so now I remember another reason why I was holding back on is_monotonically_decreasing: AFAIK Scale is the only monotonically decreasing function we can have, but how do we implement is_monotonically_decreasing for ComposedFunction?

The condition

```julia
is_monotonically_decreasing(f.inner) && is_monotonically_decreasing(f.outer)
```

won't be correct: e.g. `Scale(-1)` and `Scale(-1)` are both monotonically decreasing, but their composition is not.

EDIT: Though this is of course also an issue for is_monotonically_increasing...

EDIT 2: Nvm, it all just boils down to:

| inner \ outer | inc | dec | other |
|---------------|-----|-----|-------|
| inc           | inc | dec | NA    |
| dec           | dec | inc | NA    |
| other         | NA  | NA  | NA    |
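In code, the table amounts to something like this sketch for `ComposedFunction` (not necessarily the PR's final implementation):

```julia
function is_monotonically_increasing(f::ComposedFunction)
    return (is_monotonically_increasing(f.inner) && is_monotonically_increasing(f.outer)) ||
           (is_monotonically_decreasing(f.inner) && is_monotonically_decreasing(f.outer))
end

function is_monotonically_decreasing(f::ComposedFunction)
    return (is_monotonically_increasing(f.inner) && is_monotonically_decreasing(f.outer)) ||
           (is_monotonically_decreasing(f.inner) && is_monotonically_increasing(f.outer))
end
```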

@sethaxen (Member)

I don't understand the table, but I believe it amounts to first checking that all bijectors are (elementwise) univariate with all(x -> is_monotonically_increasing(x) | is_monotonically_decreasing(x), bijectors) and then checking that there are an odd number of decreasing bijectors with mapreduce(is_monotonically_decreasing, xor, bijectors).
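As a sketch of that formulation (the helper name `composition_is_decreasing` is hypothetical, not from the PR):

```julia
is_monotonic(b) = is_monotonically_increasing(b) || is_monotonically_decreasing(b)

function composition_is_decreasing(bijectors)
    # All components must be (elementwise) univariate and strictly monotonic ...
    all(is_monotonic, bijectors) || return false
    # ... and the composition is decreasing iff an odd number of them are decreasing.
    return mapreduce(is_monotonically_decreasing, xor, bijectors)
end
```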

Review thread on src/bijectors/logit.jl (outdated, resolved).
@torfjelde (Member, Author)

My table is conveying the same idea, just on a per-composition-basis (since we're defining the method for ComposedFunction) :)

But I've now added support for monotonically decreasing functions too + tests:)

@torfjelde (Member, Author)

Aaaalrighty! Finally got this thing working:)

The issue was that we added one too many transformations, @sethaxen: it should just be `inverse(OrderedBijector()) ∘ b`, so that we get a transformation from constrained to real, not `binv ∘ inverse(OrderedBijector()) ∘ b`, which takes us from constrained to constrained.
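A rough sketch of the two compositions being compared (the example distribution is hypothetical; the point is that `b` is elementwise and monotonically increasing, so it preserves ordering):

```julia
using Bijectors, Distributions, LinearAlgebra

# Example distribution with positive support, so its bijector is an
# elementwise, monotonically increasing map to ℝ.
dist = product_distribution(fill(LogNormal(), 3))

b    = bijector(dist)   # constrained (positive) → ℝ³, monotonically increasing
binv = inverse(b)       # ℝ³ → constrained

# Correct: b(x) of an ordered x is still ordered (b is increasing), and
# inverse(OrderedBijector()) maps the ordered vector to all of ℝ³.
to_unconstrained = inverse(OrderedBijector()) ∘ b

# Wrong: appending binv maps back into the constrained space again,
# giving constrained → constrained rather than constrained → ℝ³.
constrained_to_constrained = binv ∘ inverse(OrderedBijector()) ∘ b
```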

BUT one final thing: what should we put as a warning regarding usage of `ordered`, @sethaxen? I'm still a bit uncertain about exactly what you meant; I thought I understood it wrt. the restriction and not accounting for the normalization constant, but then the example with a changing normalization constant (the variance parameter changing in a MvNormal) seemed to work, so now I'm confused again 🤷

@torfjelde (Member, Author)

Damn. Seems like we missed something in #313

@torfjelde (Member, Author)

Note that there doesn't seem to be anything incorrect with the impl, but it's failing because it's trying to compare elements which aren't part of the triangular part

@torfjelde (Member, Author)

Pfft well that was painful. Added comments regarding what the issue is + fixed it by introducing a wrapper to avoid comparing Matrix values with potentially undef entries.

Would you have a quick look at some point, @sethaxen? 🙏 Think we're there now after we've addressed the following:)

> BUT one final thing: what should we put as a warning regarding usage of `ordered`, @sethaxen? I'm still a bit uncertain about exactly what you meant; I thought I understood it wrt. the restriction and not accounting for the normalization constant, but then the example with a changing normalization constant (the variance parameter changing in a MvNormal) seemed to work, so now I'm confused again 🤷

Review threads on test/ad/chainrules.jl (outdated, resolved).
torfjelde and others added 5 commits June 5, 2024 13:46
@sethaxen (Member) commented Jun 5, 2024

> Would you have a quick look at some point, @sethaxen? 🙏 Think we're there now after we've addressed the following:)

Cool, will try to review this evening.

> BUT one final thing: what should we put as a warning regarding usage of `ordered`, @sethaxen? I'm still a bit uncertain about exactly what you meant; I thought I understood it wrt. the restriction and not accounting for the normalization constant, but then the example with a changing normalization constant (the variance parameter changing in a MvNormal) seemed to work, so now I'm confused again 🤷

Sounds weird, I'll check the test.

@torfjelde (Member, Author)

> Sounds weird, I'll check the test.

I was talking about the example that we were discussing in one of the other comments; specifically #297 (comment)

@sethaxen (Member) commented Jun 6, 2024

Ah, that's expected though. I assume you un-fixed the variance parameter and randomly sampled it within the rejection sampling inner loop? The issue here is that when the mean is the same for both components, the variance actually has no impact on whether they are ordered. I think you should see a difference if you make the mean a reverse-ordered vector. The further apart the two mean components are, the more pronounced the difference and the harder it is to rejection sample.

@torfjelde (Member, Author)

The example I'm talking about is not related to rejection sampling; I'm referring to the example you ran with NUTS:)

@sethaxen (Member) commented Jun 6, 2024

> I was talking about the example that we were discussing in one of the other comments; specifically #297 (comment)

> The example I'm talking about is not related to rejection sampling; I'm referring to the example you ran with NUTS:)

I'm confused which example you're referring to then. The one in the comment you linked to compares NUTS with rejection sampling, but it does so with fixed variance, so it would not manifest the issue I'm talking about. Here's an example that does:

Example:

```julia
using Turing
using Bijectors: ordered
using LinearAlgebra
using Random: Random
using PosteriorStats

@model function demo_ordered(μ)
    k = length(μ)
    σ² ~ filldist(truncated(Normal(), lower=0), k)
    x ~ ordered(MvNormal(μ, Diagonal(σ²)))
    return (; σ², x)
end

k = 2
num_samples = 1_000_000
num_chains = 8

# Sample using NUTS.
μ = [3, 0]  # note: reverse-ordered, most draws will be rejected
model = demo_ordered(μ)

Random.seed!(0)
chain = sample(model, NUTS(), MCMCThreads(), num_samples ÷ num_chains, num_chains)
xs_chain = permutedims(Array(chain))

σ²_chain = cat(only(get(chain, :σ²))...; dims=3)

# Rejection sampling.
σ²_exact = mapreduce(hcat, 1:num_samples) do _
    while true
        σ² = rand(filldist(truncated(Normal(), lower=0), k))
        d = MvNormal(μ, Diagonal(σ²))
        xs = rand(d)
        issorted(xs) && return σ²
    end
end

qts = [0.05, 0.25, 0.5, 0.75, 0.95]
qt_names = map(q -> Symbol("q$(Int(100 * q))"), qts)
stats_with_mcses = (
    Tuple(qt_names) => Base.Fix2(quantile, qts),
    (Symbol("$(qn)_mcse") => (x -> mcse(x; kind=Base.Fix2(quantile, q))) for (q, qn) in zip(qts, qt_names))...,
)
```

```
julia> PosteriorStats.summarize(σ²_chain, stats_with_mcses...; var_names=["σ²[1]", "σ²[2]"])
SummaryStats
           q5    q25    q50    q75    q95  q5_mcse  q25_mcse  q50_mcse  q75_mcse  q95_mcse
 σ²[1]  0.239  0.769  1.249  1.77   2.531   0.0031    0.0024    0.0030    0.0053    0.0031
 σ²[2]  0.17   0.74   1.220  1.740  2.533   0.012     0.0052    0.0031    0.0025    0.0026

julia> PosteriorStats.summarize(reshape(σ²_exact', :, 1, 2), stats_with_mcses...; var_names=["σ²[1]", "σ²[2]"])
SummaryStats
           q5    q25    q50    q75    q95  q5_mcse  q25_mcse  q50_mcse  q75_mcse  q95_mcse
 σ²[1]  0.226  0.758  1.234  1.747  2.538  0.00087   0.00091   0.00092    0.0010    0.0017
 σ²[2]  0.227  0.759  1.235  1.749  2.535  0.00081   0.00096   0.00094    0.0010    0.0020
```

Note that the rejection sampling approach makes sense. The quantiles of the two variances should be about the same, since to get an ordered draw with a well-separated reverse-ordered mean, one needs to increase the variance, but it doesn't matter which variance is increased. But if we look at the HMC draws, we see that there's an asymmetry between the variances. This is due to the missing normalization factor. If we had a closed-form expression for it, we could test that, but I don't know one.

@sethaxen (Member) commented Jun 6, 2024

TBH I'm not certain if the above examples are even correct. The place I expect this to manifest is when conditioning, which is implicitly what the rejection-sampling approach is doing (conditioning on x[1] < x[2]).

@torfjelde (Member, Author) commented Jun 6, 2024

> I'm confused which example you're referring to then. The one in the comment you linked to compares NUTS with rejection sampling, but it does so with fixed variance, so it would not manifest the issue I'm talking about.

Completely missed the fact that we were fixing the variance 🤦

> Note that the rejection sampling approach makes sense. The quantiles of the two variances should be about the same, since to get an ordered draw with a well-separated reverse-ordered mean, one needs to increase the variance, but it doesn't matter which variance is increased.

Gotcha, gotcha; understand better now 👍

Soooo how do we summarize all this into a simple warning for the end-user? 👀

@torfjelde (Member, Author)

This PR is just waiting for the following:

> Soooo how do we summarize all this into a simple warning for the end-user?

Think it's worth waiting until @sethaxen is back to let him have a final say before we merge 👍

@sethaxen (Member) commented Jun 26, 2024

> Soooo how do we summarize all this into a simple warning for the end-user? 👀

Maybe an admonition saying something like:

> The resulting ordered distribution is un-normalized. This is not a problem if used in a context where the normalizing factor is irrelevant, but if the value of the normalizing factor impacts the resulting computation, the results may be inaccurate. For example, if the distribution is used in sampling a posterior distribution with MCMC and the parameters of the ordered distribution are themselves sampled, then the normalizing factor would in general be needed for accurate sampling, and `ordered` should not be used. However, if the parameters are fixed, then since MCMC does not require distributions be normalized, `ordered` may be used without problems. A common case is where the distribution being ordered is a joint distribution of `n` identical univariate distributions. In this case the normalization factor works out to be the constant `n!`, and `ordered` can again be used without problems even if the parameters of the univariate distribution are sampled.
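For illustration, the safe "fixed parameters" case from the admonition might look like this sketch (hypothetical model, assuming Turing and Bijectors as in the earlier example):

```julia
using Turing
using Bijectors: ordered
using LinearAlgebra

# Parameters of the ordered distribution are fixed, so the missing
# normalizing factor is a constant and MCMC sampling is unaffected.
@model function demo_fixed()
    x ~ ordered(MvNormal(zeros(3), I))
end

chain = sample(demo_fixed(), NUTS(), 1_000)
```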

Not in love with it; feels too wordy.

@torfjelde (Member, Author)

> Not in love with it; feels too wordy.

Added it, but I did break it up a bit + added a shorter warning to the initial part of the docstring:) Thanks @sethaxen!

@yebai requested a review from sethaxen on June 27, 2024.
@sethaxen (Member) left a comment:

LGTM! Thanks @torfjelde for taking on this fix!

@yebai merged commit 026a07a into master on Jun 27, 2024 (23 checks passed), then deleted the torfjelde/ordered-for-monotonic branch.