Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement congruency check in EnzymeCore #637

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lib/EnzymeCore/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ version = "0.2.1"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"

[compat]
Adapt = "3.3"
Preferences = "1.3"
julia = "1.6"
79 changes: 79 additions & 0 deletions lib/EnzymeCore/src/EnzymeCore.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,33 @@
module EnzymeCore

using Adapt
using Preferences

export Forward, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal
export ReverseSplitModified, ReverseSplitWidth
export Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed

const structure_check = parse(Bool, @load_preference("structure_check", "false"))

"""
structure_check!(flag)

Toggle the default setting for congruence/structure checking.
"""
function structure_check!(flag)
@set_preferences!("structure_check" => flag)
@info("structure_check toggled, restart your Julia session for this change to take effect!")

if VERSION <= v"1.6.5" || VERSION == v"1.7.0"
@warn """
Due to a bug in Julia (until 1.6.5 and 1.7.1), setting preferences in transitive dependencies
is broken (https://github.com/JuliaPackaging/Preferences.jl/issues/24). To fix this either update
your version of Julia, or add EnzymeCore as a direct dependency to your project.
"""
end
return nothing
end

function batch_size end

"""
Expand Down Expand Up @@ -61,6 +83,10 @@ accumulator for gradients (so ``\\partial f / \\partial x`` will be *added to*)
struct Duplicated{T} <: Annotation{T}
val::T
dval::T
function Duplicated(val::T, dval::T; checked=structure_check) where T
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if the preference isn't there, will this be a static no check, or will it load a global then check?

If the latter, I want to have a way to avoid the performance penalty.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a global const so it's propagated (and it's a compile time preference)

checked && check_congruence(val, dval)
new{T}(val, dval)
end
end
Adapt.adapt_structure(to, x::Duplicated) = Duplicated(adapt(to, x.val), adapt(to, x.dval))

Expand All @@ -73,6 +99,10 @@ the original result and only compute the derivative values.
struct DuplicatedNoNeed{T} <: Annotation{T}
val::T
dval::T
function DuplicatedNoNeed(val::T, dval::T; checked=structure_check) where T
checked && check_congruence(val, dval)
new{T}(val, dval)
end
end
Adapt.adapt_structure(to, x::DuplicatedNoNeed) = DuplicatedNoNeed(adapt(to, x.val), adapt(to, x.dval))

Expand All @@ -85,6 +115,10 @@ for all at once. Argument `∂f_∂xs` should be a tuple of the several values o
struct BatchDuplicated{T,N} <: Annotation{T}
val::T
dval::NTuple{N,T}
function BatchDuplicated(val::T, dval::NTuple{N,T}; checked=structure_check) where {T, N}
checked && check_congruence(val, dval)
new{T, N}(val, dval)
end
end
Adapt.adapt_structure(to, x::BatchDuplicated) = BatchDuplicated(adapt(to, x.val), adapt(to, x.dval))

Expand All @@ -97,11 +131,56 @@ for all at once. Argument `∂f_∂xs` should be a tuple of the several values o
struct BatchDuplicatedNoNeed{T,N} <: Annotation{T}
val::T
dval::NTuple{N,T}
function BatchDuplicatedNoNeed(val::T, dval::NTuple{N,T}; checked=structure_check) where {T, N}
checked && check_congruence(val, dval)
new{T, N}(val, dval)
end
end
batch_size(::BatchDuplicated{T,N}) where {T,N} = N
batch_size(::BatchDuplicatedNoNeed{T,N}) where {T,N} = N
Adapt.adapt_structure(to, x::BatchDuplicatedNoNeed) = BatchDuplicatedNoNeed(adapt(to, x.val), adapt(to, x.dval))

"""
congruent(a::T, b::T)::T

Defines values to be congruent, e.g. structurally equivalent.
"""
function congruent end

congruent(a, b) = false
congruent(a::T, b::T) where T<:Number = true
function congruent(a::T, b::T) where T<:DenseArray
axes(a) == axes(b) && all(congruent, zip(a, b))
end
congruent(a::T, b::T) where T<:Ref = congruent(a[], b[])
congruent(a::T, b::T) where T<:Tuple = all(congruent, zip(a, b))
congruent(a::T, b::T) where T<:NamedTuple = all(congruent, zip(a, b))

congruent(tup::Tuple{T, T}) where T = congruent(tup...)

function check_congruence(a::T, b::T) where T
# TODO: Use once hasmethod is static
# if !hasmethod(congruent, Tuple{T, T})
# error("""
# Implement EnzymeCore.congruent(a, b) for your type $T
# """)
# end
if !congruent(a, b)
error("""
Your values are not congruent, structural equivalence is
requirement for the correctness of the adjoint pass.

You may need to implement EnzymeCore.congruent(a, b) for your type $T
""")
end
end

function check_congruence(a::T, b::NTuple{N, T}) where {N, T}
ntuple(Val(N)) do i
check_congruence(a, b[i])
end
end

"""
abstract type Mode

Expand Down