From c0a8e2d3388fe2105eac2e7e01ae06fe96857443 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Thu, 23 Feb 2023 18:43:13 -0500 Subject: [PATCH 1/4] Implement congruency check in EnzymeCore --- lib/EnzymeCore/src/EnzymeCore.jl | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index 2658d66acf..5e9e5c4dc5 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -61,6 +61,10 @@ accumulator for gradients (so ``\\partial f / \\partial x`` will be *added to*) struct Duplicated{T} <: Annotation{T} val::T dval::T + function Duplicated(val::T, dval::T) where T + check_congruence(val, dval) + new{T}(val, dval) + end end Adapt.adapt_structure(to, x::Duplicated) = Duplicated(adapt(to, x.val), adapt(to, x.dval)) @@ -102,6 +106,31 @@ batch_size(::BatchDuplicated{T,N}) where {T,N} = N batch_size(::BatchDuplicatedNoNeed{T,N}) where {T,N} = N Adapt.adapt_structure(to, x::BatchDuplicatedNoNeed) = BatchDuplicatedNoNeed(adapt(to, x.val), adapt(to, x.dval)) +""" + congruent(a::T, b::T)::T + +Defines values to be congruent, e.g. structurally equivalent. +""" +function congruent end + +congruent(a::T, b::T) where T<:Number = true +congruent(a::T, b::T) where T<:AbstractArray = length(a) == length(b) + +function check_congruence(a::T, b::T) where T + # TODO: Use once hasmethod is static + # if !hasmethod(congruent, Tuple{T, T}) + # error(""" + # Implement EnzymeCore.congruent(a, b) for your type $T + # """) + # end + if !congruent(a, b) + error(""" + Your values are not congruent, structural equivalence is + requirement for the correctness of the adjoint pass. + """) + end +end + """ abstract type Mode From 25097dcd497bd70da3de721517f20cfe8204be8b Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Sun, 12 Mar 2023 13:46:48 -0400 Subject: [PATCH 2/4] Optional congruence check with Preferences based opt-in --- lib/EnzymeCore/Project.toml | 2 ++ lib/EnzymeCore/src/EnzymeCore.jl | 39 ++++++++++++++++++++++++++++---- 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/lib/EnzymeCore/Project.toml b/lib/EnzymeCore/Project.toml index febc08bfa6..6e94c9eeee 100644 --- a/lib/EnzymeCore/Project.toml +++ b/lib/EnzymeCore/Project.toml @@ -5,7 +5,9 @@ version = "0.2.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +Preferences = "21216c6a-2e73-6563-6e65-726566657250" [compat] Adapt = "3.3" +Preferences = "1.3" julia = "1.6" diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index 5e9e5c4dc5..f3e480cc18 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -1,11 +1,14 @@ module EnzymeCore using Adapt +using Preferences export Forward, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal export ReverseSplitModified, ReverseSplitWidth export Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed +const structure_check = parse(Bool, @load_preference("structure_check", "false")) + function batch_size end """ @@ -61,8 +64,8 @@ accumulator for gradients (so ``\\partial f / \\partial x`` will be *added to*) struct Duplicated{T} <: Annotation{T} val::T dval::T - function Duplicated(val::T, dval::T) where T - check_congruence(val, dval) + function Duplicated(val::T, dval::T; checked=structure_check) where T + checked && check_congruence(val, dval) new{T}(val, dval) end end @@ -77,6 +80,10 @@ the original result and only compute the derivative values. struct DuplicatedNoNeed{T} <: Annotation{T} val::T dval::T + function DuplicatedNoNeed(val::T, dval::T; checked=structure_check) where T + checked && check_congruence(val, dval) + new{T}(val, dval) + end end Adapt.adapt_structure(to, x::DuplicatedNoNeed) = DuplicatedNoNeed(adapt(to, x.val), adapt(to, x.dval)) @@ -89,6 +96,10 @@ for all at once. Argument `∂f_∂xs` should be a tuple of the several values o struct BatchDuplicated{T,N} <: Annotation{T} val::T dval::NTuple{N,T} + function BatchDuplicated(val::T, dval::NTuple{N,T}; checked=structure_check) where {T, N} + checked && check_congruence(val, dval) + new{T, N}(val, dval) + end end Adapt.adapt_structure(to, x::BatchDuplicated) = BatchDuplicated(adapt(to, x.val), adapt(to, x.dval)) @@ -101,6 +112,10 @@ for all at once. Argument `∂f_∂xs` should be a tuple of the several values o struct BatchDuplicatedNoNeed{T,N} <: Annotation{T} val::T dval::NTuple{N,T} + function BatchDuplicatedNoNeed(val::T, dval::NTuple{N,T}; checked=structure_check) where {T, N} + checked && check_congruence(val, dval) + new{T, N}(val, dval) + end end batch_size(::BatchDuplicated{T,N}) where {T,N} = N batch_size(::BatchDuplicatedNoNeed{T,N}) where {T,N} = N @@ -113,8 +128,16 @@ Defines values to be congruent, e.g. structurally equivalent. """ function congruent end +congruent(a, b) = false congruent(a::T, b::T) where T<:Number = true -congruent(a::T, b::T) where T<:AbstractArray = length(a) == length(b) +function congruent(a::T, b::T) where T<:DenseArray + axes(a) == axes(b) && all(congruent, zip(a, b)) +end +congruent(a::T, b::T) where T<:Ref = congruent(a[], b[]) +congruent(a::T, b::T) where T<:Tuple = all(congruent, zip(a, b)) +congruent(a::T, b::T) where T<:NamedTuple = all(congruent, zip(a, b)) + +congruent(tup::Tuple{T, T}) where T = congruent(tup...) function check_congruence(a::T, b::T) where T # TODO: Use once hasmethod is static @@ -127,10 +150,18 @@ function check_congruence(a::T, b::T) where T error(""" Your values are not congruent, structural equivalence is requirement for the correctness of the adjoint pass. + + You may need to implement EnzymeCore.congruent(a, b) for your type $T """) end end - + +function check_congruence(a::T, b::NTuple{N, T}) where {N, T} + ntuple(Val(N)) do i + check_congruence(a, b[i]) + end +end + """ abstract type Mode From 1c7f23afad5125f662260d1fe4bf5cc211efed0e Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Sun, 12 Mar 2023 13:50:55 -0400 Subject: [PATCH 3/4] fixup! Optional congruence check with Preferences based opt-in --- lib/EnzymeCore/src/EnzymeCore.jl | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index f3e480cc18..ce3653cfdb 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -9,6 +9,25 @@ export Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplic const structure_check = parse(Bool, @load_preference("structure_check", "false")) +""" + structure_check!(flag) + +Toggle the default setting for congruence/structure checking. +""" +function structure_check!(flag) + @set_preferences!("structure_check" => flag) + @info("structure_check toggled, restart your Julia session for this change to take effect!") + + if VERSION <= v"1.6.5" || VERSION == v"1.7.0" + @warn """ + Due to a bug in Julia (until 1.6.5 and 1.7.1), setting preferences in transitive dependencies + is broken (https://github.com/JuliaPackaging/Preferences.jl/issues/24). To fix this either update + your version of Julia, or add EnzyemCore as a direct dependency to your project. + """ + end + return nothing +end + function batch_size end """ From fceb201f38c57c9584a737bfd44056dae2e8eeb6 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Sun, 12 Mar 2023 14:46:17 -0400 Subject: [PATCH 4/4] Apply suggestions from code review --- lib/EnzymeCore/src/EnzymeCore.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index ce3653cfdb..703be5cc2b 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -22,7 +22,7 @@ function structure_check!(flag) @warn """ Due to a bug in Julia (until 1.6.5 and 1.7.1), setting preferences in transitive dependencies is broken (https://github.com/JuliaPackaging/Preferences.jl/issues/24). To fix this either update - your version of Julia, or add EnzyemCore as a direct dependency to your project. + your version of Julia, or add EnzymeCore as a direct dependency to your project. """ end return nothing