Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make suffstats and fit_mle type-generic for Normal{T} #1560

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 71 additions & 53 deletions src/univariate/continuous/normal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,38 +120,41 @@ rand!(rng::AbstractRNG, d::Normal, A::AbstractArray{<:Real}) = A .= muladd.(d.σ

#### Fitting

struct NormalStats <: SufficientStats
s::Float64 # (weighted) sum of x
m::Float64 # (weighted) mean of x
s2::Float64 # (weighted) sum of (x - μ)^2
tw::Float64 # total sample weight
struct NormalStats{T<:Real} <: SufficientStats
s::T # (weighted) sum of x
m::T # (weighted) mean of x
s2::T # (weighted) sum of (x - μ)^2
tw::T # total sample weight
function NormalStats(s::T1, m::T2, s2::T3, tw::T4) where {T1,T2,T3,T4}
T = promote_type(T1, T2, T3, T4)
return new{T}(T(s), T(m), T(s2), T(tw))
end
end

function suffstats(::Type{<:Normal}, x::AbstractArray{T}) where T<:Real
n = length(x)

# compute s
s = zero(T) + zero(T)
s = zero(T)
for i in eachindex(x)
@inbounds s += x[i]
end
m = s / n

# compute s2
s2 = zero(m)
s2 = zero(T)
for i in eachindex(x)
@inbounds s2 += abs2(x[i] - m)
end

NormalStats(s, m, s2, n)
end

function suffstats(::Type{<:Normal}, x::AbstractArray{T}, w::AbstractArray{Float64}) where T<:Real
n = length(x)

function suffstats(::Type{<:Normal}, x::AbstractArray{T1}, w::AbstractArray{T2}) where {T1<:Real,T2<:Real}
T = promote_type(T1, T2)
# compute s
tw = 0.0
s = 0.0 * zero(T)
tw = zero(T)
s = zero(T)
for i in eachindex(x, w)
@inbounds wi = w[i]
@inbounds s += wi * x[i]
Expand All @@ -160,7 +163,7 @@ function suffstats(::Type{<:Normal}, x::AbstractArray{T}, w::AbstractArray{Float
m = s / tw

# compute s2
s2 = zero(m)
s2 = zero(T)
for i in eachindex(x, w)
@inbounds s2 += w[i] * abs2(x[i] - m)
end
Expand All @@ -170,29 +173,35 @@ end

# Cases where μ or σ is known

struct NormalKnownMu <: IncompleteDistribution
μ::Float64
struct NormalKnownMu{T<:Real} <: IncompleteDistribution
μ::T
end

struct NormalKnownMuStats <: SufficientStats
μ::Float64 # known mean
s2::Float64 # (weighted) sum of (x - μ)^2
tw::Float64 # total sample weight
struct NormalKnownMuStats{T<:Real} <: SufficientStats
μ::T # known mean
s2::T # (weighted) sum of (x - μ)^2
tw::T # total sample weight
function NormalKnownMuStats(μ::T1, s2::T2, tw::T3) where {T1,T2,T3}
T = promote_type(T1, T2, T3)
return new{T}(μ, s2, tw)
end
end

function suffstats(g::NormalKnownMu, x::AbstractArray{T}) where T<:Real
function suffstats(g::NormalKnownMu{T0}, x::AbstractArray{T1}) where {T0,T1<:Real}
T = promote_type(T0, T1)
μ = g.μ
s2 = zero(T) + zero(μ)
s2 = zero(T)
for i in eachindex(x)
@inbounds s2 += abs2(x[i] - μ)
end
NormalKnownMuStats(g.μ, s2, length(x))
end

function suffstats(g::NormalKnownMu, x::AbstractArray{T}, w::AbstractArray{Float64}) where T<:Real
function suffstats(g::NormalKnownMu{T0}, x::AbstractArray{T1}, w::AbstractArray{T2}) where {T0,T1<:Real,T2<:Real}
T = promote_type(T0, T1, T2)
μ = g.μ
s2 = 0.0 * abs2(zero(T) - zero(μ))
tw = 0.0
s2 = zero(T)
tw = zero(T)
for i in eachindex(x, w)
@inbounds wi = w[i]
@inbounds s2 += abs2(x[i] - μ) * wi
Expand All @@ -201,69 +210,78 @@ function suffstats(g::NormalKnownMu, x::AbstractArray{T}, w::AbstractArray{Float
NormalKnownMuStats(g.μ, s2, tw)
end

struct NormalKnownSigma <: IncompleteDistribution
σ::Float64

function NormalKnownSigma(σ::Float64)
struct NormalKnownSigma{T<:Real} <: IncompleteDistribution
σ::T
function NormalKnownSigma(σ::T) where {T}
σ > 0 || throw(ArgumentError("σ must be a positive value."))
new(σ)
return new{T}(σ)
end
end

struct NormalKnownSigmaStats <: SufficientStats
σ::Float64 # known std.dev
sx::Float64 # (weighted) sum of x
tw::Float64 # total sample weight
struct NormalKnownSigmaStats{T<:Real} <: SufficientStats
σ::T # known std.dev
sx::T # (weighted) sum of x
tw::T # total sample weight
function NormalKnownSigmaStats(σ::T1, sx::T2, tw::T3) where {T1,T2,T3}
T = promote_type(T1, T2, T3)
return new{T}(σ, sx, tw)
end
end

function suffstats(g::NormalKnownSigma, x::AbstractArray{T}) where T<:Real
NormalKnownSigmaStats(g.σ, sum(x), Float64(length(x)))
function suffstats(g::NormalKnownSigma, x::AbstractArray{<:Real})
NormalKnownSigmaStats(g.σ, sum(x), length(x))
end

function suffstats(g::NormalKnownSigma, x::AbstractArray{T}, w::AbstractArray{T}) where T<:Real
function suffstats(g::NormalKnownSigma, x::AbstractArray{<:Real}, w::AbstractArray{<:Real})
NormalKnownSigmaStats(g.σ, dot(x, w), sum(w))
end

# fit_mle based on sufficient statistics

fit_mle(::Type{<:Normal}, ss::NormalStats) = Normal(ss.m, sqrt(ss.s2 / ss.tw))
fit_mle(::Type{D}, ss::NormalStats) where {D<:Normal} = D(ss.m, sqrt(ss.s2 / ss.tw))
fit_mle(g::NormalKnownMu, ss::NormalKnownMuStats) = Normal(g.μ, sqrt(ss.s2 / ss.tw))
fit_mle(g::NormalKnownSigma, ss::NormalKnownSigmaStats) = Normal(ss.sx / ss.tw, g.σ)

# generic fit_mle methods

function fit_mle(::Type{<:Normal}, x::AbstractArray{T}; mu::Float64=NaN, sigma::Float64=NaN) where T<:Real
if isnan(mu)
if isnan(sigma)
fit_mle(Normal, suffstats(Normal, x))
function fit_mle(
::Type{D}, x::AbstractArray{<:Real};
mu::Union{Nothing,<:Real}=nothing, sigma::Union{Nothing,<:Real}=nothing
) where {D<:Normal}
if isnothing(mu)
if isnothing(sigma)
fit_mle(D, suffstats(Normal, x))
else
g = NormalKnownSigma(sigma)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should find a way to preserve/forward the type of D in these cases as well.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which eltype should we pick for D though, the one of g or the one of the suffstats?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a conversion in my latest commit, what do you think?

fit_mle(g, suffstats(g, x))
convert(D, fit_mle(g, suffstats(g, x)))
end
else
if isnan(sigma)
if isnothing(sigma)
g = NormalKnownMu(mu)
fit_mle(g, suffstats(g, x))
convert(D, fit_mle(g, suffstats(g, x)))
else
Normal(mu, sigma)
D(mu, sigma)
end
end
end

function fit_mle(::Type{<:Normal}, x::AbstractArray{T}, w::AbstractArray{Float64}; mu::Float64=NaN, sigma::Float64=NaN) where T<:Real
if isnan(mu)
if isnan(sigma)
fit_mle(Normal, suffstats(Normal, x, w))
function fit_mle(
::Type{D}, x::AbstractArray{<:Real}, w::AbstractArray{<:Real};
mu::Union{Nothing,<:Real}=nothing, sigma::Union{Nothing,<:Real}=nothing
) where {D<:Normal}
if isnothing(mu)
if isnothing(sigma)
fit_mle(D, suffstats(Normal, x, w))
else
g = NormalKnownSigma(sigma)
fit_mle(g, suffstats(g, x, w))
convert(D, fit_mle(g, suffstats(g, x, w)))
end
else
if isnan(sigma)
if isnothing(sigma)
g = NormalKnownMu(mu)
fit_mle(g, suffstats(g, x, w))
convert(D, fit_mle(g, suffstats(g, x, w)))
else
Normal(mu, sigma)
D(mu, sigma)
end
end
end
Loading