Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace old Gibbs sampler with the experimental one. #2328

Open
wants to merge 50 commits into
base: master
Choose a base branch
from

Conversation

mhauru
Copy link
Member

@mhauru mhauru commented Sep 23, 2024

Closes #2318.

Work in progress.

Copy link

codecov bot commented Sep 23, 2024

Codecov Report

Attention: Patch coverage is 51.33929% with 109 lines in your changes missing coverage. Please review.

Project coverage is 73.80%. Comparing base (5b24ceb) to head (053ecfc).

Files with missing lines Patch % Lines
src/mcmc/gibbs.jl 51.23% 99 Missing ⚠️
src/mcmc/sghmc.jl 0.00% 4 Missing ⚠️
src/mcmc/Inference.jl 0.00% 2 Missing ⚠️
src/mcmc/emcee.jl 0.00% 1 Missing ⚠️
src/mcmc/ess.jl 0.00% 1 Missing ⚠️
src/mcmc/is.jl 0.00% 1 Missing ⚠️
src/mcmc/particle_mcmc.jl 66.66% 1 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (5b24ceb) and HEAD (053ecfc). Click for more details.

HEAD has 33 uploads less than BASE
Flag BASE (5b24ceb) HEAD (053ecfc)
65 32
Additional details and impacted files
@@             Coverage Diff             @@
##           master    #2328       +/-   ##
===========================================
- Coverage   86.41%   73.80%   -12.61%     
===========================================
  Files          22       20        -2     
  Lines        1575     1554       -21     
===========================================
- Hits         1361     1147      -214     
- Misses        214      407      +193     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@coveralls
Copy link

coveralls commented Sep 23, 2024

Pull Request Test Coverage Report for Build 11821808857

Details

  • 108 of 224 (48.21%) changed or added relevant lines in 10 files are covered.
  • 103 unchanged lines in 5 files lost coverage.
  • Overall coverage decreased (-12.8%) to 65.762%

Changes Missing Coverage Covered Lines Changed/Added Lines %
src/mcmc/emcee.jl 0 1 0.0%
src/mcmc/ess.jl 0 1 0.0%
src/mcmc/is.jl 0 1 0.0%
src/mcmc/particle_mcmc.jl 2 3 66.67%
src/mcmc/Inference.jl 0 2 0.0%
src/mcmc/sghmc.jl 0 4 0.0%
src/mcmc/hmc.jl 0 7 0.0%
src/mcmc/gibbs.jl 104 203 51.23%
Files with Coverage Reduction New Missed Lines %
src/mcmc/hmc.jl 1 0.0%
src/mcmc/abstractmcmc.jl 8 78.72%
src/mcmc/particle_mcmc.jl 11 86.75%
src/mcmc/Inference.jl 31 67.66%
src/mcmc/ess.jl 52 0.0%
Totals Coverage Status
Change from base Build 11725449698: -12.8%
Covered Lines: 1018
Relevant Lines: 1548

💛 - Coveralls


The old Gibbs constructor relied on being called with several subsamplers, and each of the constructors of the subsamplers would take as arguments the symbols for the variables that they are to sample, e.g. `Gibbs(HMC(:x), MH(:y))`. This constructor has been deprecated, and will be removed in the future. The new constructor works by assigning samplers to either symbols or `VarNames`, e.g. `Gibbs(; x=HMC(), y=MH())` or `Gibbs(@varname(x) => HMC(), @varname(y) => MH())`. This allows more granular specification of which sampler to use for which variable.

Likewise, the old constructor for calling one subsampler more often than another, `Gibbs((HMC(:x), 2), (MH(:y), 1))` has been deprecated. The new way to achieve this effect is to list the same sampler multiple times, e.g. as `hmc = HMC(); mh = MH(); Gibbs(@varname(x) => hmc, @varname(x) => hmc, @varname(y) => mh)`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gibbs(@varname(x) => hmc, @varname(x) => hmc, @varname(y) => mh)

This looks rather awkward. Can we introduce a simple wrapper, Repeated and support:

Gibbs(@varname(x) => Repeated(hmc, n), @varname(y) => mh)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. We had a chat about a closely related issue with @torfjelde too, I'll rework the interface around this a bit.

@mhauru
Copy link
Member Author

mhauru commented Sep 26, 2024

@torfjelde, if you have a moment to take a look at the one remaining test failure, would be interested in your thoughts. We are sampling for a model with two vector variables, m and z, and we seem to somehow end up with a case where there's a VarInfo with only z in it, but the sampler is looking for m too. I wonder if it's something about the interaction between particle sampling with Libtask and how the new Gibbs does things with the local varinfos. The test that fails is this one:

    @testset "dynamic model" begin
        @model function imm(y, alpha, ::Type{M}=Vector{Float64}) where {M}
            N = length(y)
            rpm = DirichletProcess(alpha)

            z = zeros(Int, N)
            cluster_counts = zeros(Int, N)
            fill!(cluster_counts, 0)

            for i in 1:N
                z[i] ~ ChineseRestaurantProcess(rpm, cluster_counts)
                cluster_counts[z[i]] += 1
            end

            Kmax = findlast(!iszero, cluster_counts)
            m = M(undef, Kmax)
            for k in 1:Kmax
                m[k] ~ Normal(1.0, 1.0)
            end
        end
        model = imm(Random.randn(100), 1.0)
        # https://github.com/TuringLang/Turing.jl/issues/1725
        # sample(model, Gibbs(MH(:z), HMC(0.01, 4, :m)), 100);
        sample(model, Gibbs(; z=PG(10), m=HMC(0.01, 4; adtype=adbackend)), 100)
    end

@torfjelde
Copy link
Member

Will have a look at this in a bit @mhauru (just need to do some grocery shopping 😬 )

@mhauru
Copy link
Member Author

mhauru commented Sep 26, 2024

Collecting links to old relevant PRs so I don't have to look for them again: #2231, #2099

@torfjelde
Copy link
Member

Think I found the error: if the number of m increases, say, from length(m) = 2 to length(m) = 3 during the PG step, then the lines

if has_conditioned_gibbs(context, vn)
value = get_conditioned_gibbs(context, vn)
return value, logpdf(right, value), vi
end
# Otherwise, falls back to the default behavior.
return DynamicPPL.tilde_assume(
rng, DynamicPPL.childcontext(context), sampler, right, vn, vi
)

doesn't hit the Gibbs branch since @varname(m[3]) is not present in the GibbsContext 😕

@torfjelde
Copy link
Member

doesn't hit the Gibbs branch since @varname(m[3]) is not present in the GibbsContext

I'm a bit uncertain how we should best handle this @yebai @mhauru

The first partially viable idea that comes to mind is to subset the varinfo to make sure that it only contains the correct variables. If we do this, then m[3] will just be "ignored" (in the varinfos) until we're actually sampling the m variables, in which case it would be captured correctly.

But this would not quite be equivalent to the current implementation of Gibbs, which, AFAIK, keeps the very first occurence of m around rather than resampling everytime. And naively, I would expect this to be incorrect.

Another way is to explicitly add the varinfos to the GibbsContext itself, and then, when we encounter a value that should in fact go into a different varinfo, we add it there. But this has a few issues:

  1. Requires the VarInfo to be mutable.
  2. Requires the VarInfo to have a container that can keep the new incoming value m[3].
  3. Implementation of Gibbs does end up being more complicated than the current approach. However, it might be worth it.

Thoughts?

@yebai
Copy link
Member

yebai commented Sep 27, 2024

Another way is to explicitly add the varinfos to the GibbsContext itself, and then, when we encounter a value that should in fact go into a different varinfo, we add it there. But this has a few issues:

Requires the VarInfo to be mutable.
Requires the VarInfo to have a container that can keep the new incoming value m[3].
Implementation of Gibbs does end up being more complicated than the current approach. However, it might be worth it.

I lean towards the above approach and (maybe later) provide explicit APIs to inference algorithms. This will enable us to handle reversible jumps (varying model dimensions) in MCMC more flexibly. At the moment, this is only possible in particle Gibbs; if it happens in HMC/MH, inference will likely fail (silently)

EDIT: we can keep VarInfos immutable by default, and requires inference developers to hook into specific APIs to mutate VarInfos.

@torfjelde
Copy link
Member

This does however complicate the new Gibbs sampling procedure quite drastically 😕

And it makes me bring up a question I really didn't think I'd be asking: is it then actually preferable to the current Gibbs with keeping it all in a single VarInfo with a flag to specify whether it should be sampled or not? 😬

I guess we should first have a go at implementing this for the new Gibbs and then we can see 👍

Another point to add to the conversation that @mhauru brought to my attention the other day: we also want to support stuff like Gibbs(@varname(m) => NUTS(), @varname(m) => HMC()), i.e. multiple samplers targeting the same variables. This adds a few "complications" (beyond addressing the growing model problem discussed above):

  1. Need to determine which varinfo to pick from varinfos based on the varnames present / targeted.
  2. A naive implementation will result in duplicated entries in varinfos. We can however address this if we really feel like it's worth it, so probably a non-issue atm.

So all in all, immediate things we need to address with Gibbs:

  1. Support changing dimensions.
  2. Support picking a varinfo to condition on based on the varnames present rather than based on ===.

@mhauru
Copy link
Member Author

mhauru commented Oct 10, 2024

I've been trying to think of a way to fix this, that would also fix the problem where different Gibbs subsamplers can't sample the same variables (e.g. you can't first sample x and y using one sampler, and then y and z with a different one). My best thought at the moment is the following design:

  1. There is only one, global VarInfo, call it vi.
  2. make_conditional takes that vi and a list of VarNames that the current subsampler samples. It hijacks the tilde pipeline to condition all other variables to their current values in vi.
  3. vi may have some variables linked, some not.
  4. Every time we call a subsampler we can hand it vi as the VarInfo. It won’t mess with any of the variables it’s not supposed to touch, because the tilde pipeline hijack from point 2.

Point 3. is maybe undesirable, but I think it’s minor compared to all the Selector/gibbsid stuff, which we would still get rid of.

The only problem I see with this is combining the local state from the previous iteration of the current subsampler with the global vi. Somehow we would need to join up-to-date information from the global vi with state-information from the previous iteration, specific to this subsampler. The right way to do this depends on the state, which is a different type of object for different subsamplers. EDIT: Actually, maybe this is okay, because we seem to already assume that every state object has a field called state.vi , we could just reset that.

The great benefit of sticking to one, global VarInfo is never having to worry about moving data between the local VarInfos. That would have to happen in both cases, when a new variable is introduced by one sampler (the failing test in this PR) and when two samplers sample the same variable. It sounds like a pain to implement.

@mhauru
Copy link
Member Author

mhauru commented Oct 10, 2024

I can imagine two different philosophies to implementing a Gibbs sampler:

  1. Every subsampler is doing its own sampling process on a low-dimensional model (a conditioned version of the full model), independent of the others. The logprobability function it's sampling from just keeps changing between iterations, because the other variables change and thus the conditioned model changes, but otherwise it's blind to the existence of the variables it isn't sampling. This is what the new Gibbs sampler does.
  2. Every subsampler is working with the same, full model, with all the variables, but only makes the changes to a subset of those variables. It still "sees" the whole model. This is what the old Gibbs sampler did.

My above proposal would essentially be doing 2., but using code that's very much like the new sampler, where the information about which sampler modifies which variables is in the sampler/GibbsContext, and not in VarInfo like it was in the old Gibbs.

The reason I'm leaning towards 2. is that 1. seems to run to some fundamental issues in cases where either

  • Variables appear and disappear based on values of other variables,
  • Two samplers want to modify the value of the same variable.

Both of those situations quite deeply violate the idea that the different subsamplers can operate mostly independently of each other.

Any thoughts very welcome, I'm still very much trying to understand the landscape of the problem.

@yebai
Copy link
Member

yebai commented Oct 10, 2024

Thanks, @mhauru, for the excellent summary of the problem and proposals. Storing conditioned variables in a context, like GibbsContext as you suggested, is very sensible. The consequence is that VarInfo and Context will have overlapped model parameters, e.g. conditioned variables will be found in both VarInfo and Context, which is fine.

In addition, it's worth mentioning that we currently have two mechanisms for passing observations to models, i.e.

(1) via model arguments, e.g. gdemo(x, y).
(2) via condition API, e.g. condition(model, (x=1,y=2)).

Among these options, (1) will hardcode observation information directly in the model while (2) stores them in a context. You could look at the DynamicPPL codebase for a more detailed picture of how it works. We want to unify these options, perhaps towards using (2) only.

This Gibbs refactoring could be an excellent starting point for a design_notes repo to record these thoughts and discussions.

@torfjelde
Copy link
Member

Every subsampler is working with the same, full model, with all the variables, but only makes the changes to a subset of those variables. It still "sees" the whole model. This is what the old Gibbs sampler did.

Overall, I'm also in favour of this @mhauru 👍 I think your reasoning is solid here.

The only other "option" I'm seeing is to keep track of which variables correpond to which varinfos (with each varinfo only containing the relevant information), but then we're effectively just re-implementing a lot of the functionality that is already provided in varinfo 😕

The only "issue" is that this does mean we have to support this "link / transform only part of the varinfo, which does mean we need something "equivalent" to all the getindex(varinfo, sampler) stuff that we've been trying to move away from (since we need a way to extract the vectorized part relevant only for the specific sampler we're going to use in that particular step) 😕

Doulby however, I think we can make this much nicer than the current approach by simply making all these getindex(varinfo, sampler) instead take the relevant varnames instead of the samplers themselves, which should make it all less painful.

But yeah, don't see how we can take approach (1) in a "nice" way, and so I'm also in favour of just trying to make (2) as painless as possible to maintain.

@mhauru
Copy link
Member Author

mhauru commented Oct 11, 2024

Thanks for the comments both, this is very helpful.

Doulby however, I think we can make this much nicer than the current approach by simply making all these getindex(varinfo, sampler) instead take the relevant varnames instead of the samplers themselves, which should make it all less painful.

Yeah, I think this is the way to go.

@mhauru mhauru marked this pull request as ready for review October 29, 2024 13:58
@mhauru mhauru requested a review from torfjelde October 29, 2024 13:58
@mhauru
Copy link
Member Author

mhauru commented Nov 1, 2024

@torfjelde, I think I made all the changes we discussed on the call. I may still move the setparams!! methods to a different file, but that's details. I hope tests will pass with ReverseDiff, but I didn't have time to run the whole suite locally before having to go, so let's see.

@mhauru
Copy link
Member Author

mhauru commented Nov 5, 2024

All the tests have passed on at least one platform, so I assume they'll all eventually pass. @torfjelde, wanna take another look?

@mhauru
Copy link
Member Author

mhauru commented Nov 5, 2024

Some optimisation is clearly required. The following runs in sub 10s with the Gibbs sampler, takes more than a minute with new one:

julia> module Benchmark1
       using Turing, Test, Random
       using Turing.RandomMeasures: ChineseRestaurantProcess, DirichletProcess
       import ReverseDiff

       adbackend = AutoReverseDiff()
       @testset "dynamic model" begin
           @model function imm(y, alpha, ::Type{M}=Vector{Float64}) where {M}
               N = length(y)
               rpm = DirichletProcess(alpha)

               z = zeros(Int, N)
               cluster_counts = zeros(Int, N)
               fill!(cluster_counts, 0)

               for i in 1:N
                   z[i] ~ ChineseRestaurantProcess(rpm, cluster_counts)
                   cluster_counts[z[i]] += 1
               end

               Kmax = findlast(!iszero, cluster_counts)
               m = M(undef, Kmax)
               for k in 1:Kmax
                   m[k] ~ Normal(1.0, 1.0)
               end
           end
           model = imm(Random.randn(100), 1.0)
           # https://github.com/TuringLang/Turing.jl/issues/1725
           # sample(model, Gibbs(MH(:z), HMC(0.01, 4, :m)), 100);
           sample(model, Gibbs(; z=PG(10), m=HMC(0.01, 4; adtype=adbackend)), 200)
       end

       end

Here's another one. With the new sampler:

julia> module Benchmark2
       using Turing, Random
       import ReverseDiff

       adbackend = AutoReverseDiff()

       @model function gdemo(x, y)
           s ~ InverseGamma(2, 3)
           m ~ Normal(0, sqrt(s))
           x ~ Normal(m, sqrt(s))
           y ~ Normal(m, sqrt(s))
           return s, m
       end

       Random.seed!(100)
       alg = Gibbs(; s=CSMC(15), m=HMC(0.2, 4; adtype=adbackend))
       @time sample(gdemo(1.5, 2.0), alg, 1_000)
       end
WARNING: replacing module Benchmark2.
 38.251275 seconds (13.16 M allocations: 965.908 MiB, 0.70% gc time, 7.68% compilation time)
Main.Benchmark2

With the old Gibbs sampler:

 16.980426 seconds (13.63 M allocations: 1.046 GiB, 1.01% gc time, 14.13% compilation time)

There are some obvious optimisations to do, I'll see where they get us.

@mhauru
Copy link
Member Author

mhauru commented Nov 7, 2024

@torfjelde and I had a conversation about performance and optimisation thereof, and concluded that we should do none of it now. Tor's argument was that the only significant cases are heavy models with lots of variables, in which case all the cost is going to be in the gradient/logprob evaluations anyway, and your component samplers will probably do lots of them. So as long as we have type stability in the tilde pipeline, everything else is details.

Based on that I

  1. Added some explicit type stability tests to the Gibbs test suite.
  2. Removed recompute_logprob!!. We call setparams!! at every iteration, and that one requires recomputing the logprob anyway.
  3. Removed the small of piece of machinery that was in place to give the user the option to not recompute logprob in some cases. This would have gotten significantly more complicated to do correctly for setparams!!, now that we can have multiple samplers affecting the same variables.

So in short, we are recomputing the logprob at every iteration of every component sampler in the setparams!! call, and we've decided that this doesn't matter enough at the moment to warrant optimising.

Tests pass, ready for review again @torfjelde.

Copy link
Member

@torfjelde torfjelde left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the delay! Took a bit to get through it properly.

A few "major" things:

  1. I think the correct thing to do is to use the GibbsContext as the very last context before we hit the leafcontext, i.e. it has to be an immediate parent of a leaf. See comments on relevant code for more info.
  2. I think I have a solution to the setgid!! stuff that you've commented on.

src/mcmc/abstractmcmc.jl Outdated Show resolved Hide resolved
src/mcmc/gibbs.jl Outdated Show resolved Hide resolved
src/mcmc/gibbs.jl Show resolved Hide resolved
test/mcmc/gibbs.jl Outdated Show resolved Hide resolved
test/mcmc/gibbs.jl Outdated Show resolved Hide resolved
src/mcmc/gibbs.jl Outdated Show resolved Hide resolved
# TODO(mhauru) Remove the below loop once samplers no longer depend on selectors.
# For some reason not having this in place was causing trouble for ESS, but not for
# other samplers. I didn't get to the bottom of it.
for vn in keys(varinfo_local)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be fixed after drop_space

test/mcmc/gibbs.jl Outdated Show resolved Hide resolved
test/mcmc/gibbs.jl Outdated Show resolved Hide resolved
test/mcmc/gibbs.jl Outdated Show resolved Hide resolved
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Remove old Gibbs sampler, make the experimental one the default
5 participants