From 65eaec0c95b391302fe0e8e55058f84aa5374be7 Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Wed, 30 Oct 2024 15:07:49 -0700 Subject: [PATCH 01/17] Implement recursive_map as basis for make_zero ...as well as recursive_add, recursive_accumulate!, and accumulate_into! --- ext/EnzymeStaticArraysExt.jl | 49 +- lib/EnzymeCore/src/EnzymeCore.jl | 122 +++- src/Enzyme.jl | 8 +- src/analyses/activity.jl | 5 + src/compiler.jl | 2 +- src/internal_rules.jl | 50 +- src/typeutils/make_zero.jl | 587 ------------------- src/typeutils/recursive_add.jl | 186 +++--- src/typeutils/recursive_maps.jl | 769 +++++++++++++++++++++++++ test/Project.toml | 1 + test/make_zero.jl | 725 ----------------------- test/recursive_maps.jl | 951 +++++++++++++++++++++++++++++++ test/runtests.jl | 139 +++-- 13 files changed, 2028 insertions(+), 1566 deletions(-) delete mode 100644 src/typeutils/make_zero.jl create mode 100644 src/typeutils/recursive_maps.jl delete mode 100644 test/make_zero.jl create mode 100644 test/recursive_maps.jl diff --git a/ext/EnzymeStaticArraysExt.jl b/ext/EnzymeStaticArraysExt.jl index ef955ebd9b..14d18a2835 100644 --- a/ext/EnzymeStaticArraysExt.jl +++ b/ext/EnzymeStaticArraysExt.jl @@ -32,50 +32,11 @@ end end end -@inline function Enzyme.EnzymeCore.make_zero( - prev::FT -) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:SArray{S,T}} - return Base.zero(prev)::FT -end -@inline function Enzyme.EnzymeCore.make_zero( - prev::FT -) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}} - return Base.zero(prev)::FT -end - -@inline function Enzyme.EnzymeCore.make_zero( - ::Type{FT}, seen::IdDict, prev::FT, ::Val{copy_if_inactive} = Val(false) -) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:SArray{S,T},copy_if_inactive} - return Base.zero(prev)::FT -end -@inline function Enzyme.EnzymeCore.make_zero( - ::Type{FT}, seen::IdDict, prev::FT, ::Val{copy_if_inactive} = Val(false) -) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T},copy_if_inactive} - if haskey(seen, prev) - return seen[prev] - end - new = Base.zero(prev)::FT - seen[prev] = new - return new -end - -@inline function Enzyme.EnzymeCore.make_zero!( - prev::FT, seen -) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}} - if !isnothing(seen) - if prev in seen - return nothing - end - push!(seen, prev) - end - fill!(prev, zero(T)) - return nothing -end -@inline function Enzyme.EnzymeCore.make_zero!( - prev::FT -) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}} - Enzyme.EnzymeCore.make_zero!(prev, nothing) - return nothing +# SArrays and MArrays don't need special treatment for `make_zero(!)` to work or be correct, +# but in case their dedicated `zero` and `fill!` methods are more efficient than +# `make_zero(!)`s recursion, we opt into treating them as leaves. +@inline function Enzyme.EnzymeCore.isvectortype(::Type{<:StaticArray{S,T}}) where {S,T} + return isbitstype(T) && Enzyme.EnzymeCore.isscalartype(T) end end diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index f949664b6a..e59534479e 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -506,28 +506,128 @@ function autodiff_thunk end function autodiff_deferred_thunk end """ - make_zero(::Type{T}, seen::IdDict, prev::T, ::Val{copy_if_inactive}=Val(false))::T - -Recursively make a zero'd copy of the value `prev` of type `T`. The argument `copy_if_inactive` specifies -what to do if the type `T` is guaranteed to be inactive, use the primal (the default) or still copy the value. + make_zero( + prev::T, ::Val{copy_if_inactive}=Val(false), ::Val{runtime_inactive}=Val(false) + )::T + make_zero( + ::Type{T}, + seen::IdDict, + prev::T, + ::Val{copy_if_inactive}=Val(false), + ::Val{runtime_inactive}=Val(false), + )::T + +Recursively make a copy of the value `prev::T` in which all differentiable values are +zeroed. The argument `copy_if_inactive` specifies what to do if the type `T` or any +of its constituent parts is guaranteed to be inactive (non-differentiable): reuse `prev`s +instance (the default) or make a copy. + +The argument `runtime_inactive` specifies whether each constituent type is checked for being +guaranteed inactive at runtime for every call to `make_zero`, or if this can be checked once +at compile-time and reused across multiple calls to `make_zero` and related functions (the +default). Runtime checks are necessary to pick up recently added methods to +`EnzymeRules.inactive_type`, but may incur a significant performance penalty and is usually +not needed unless `EnzymeRules.inactive_type` is extended interactively for types that have +previously been passed to `make_zero` or related functions. + +Extending this method for custom types is rarely needed. If you implement a new type, such +as a GPU array type, for which `make_zero` should directly invoke `zero` rather than +iterate/broadcast when the eltype is scalar, it is sufficient to implement `Base.zero` and +make sure your type subtypes `DenseArray`. (If subtyping `DenseArray` is not appropriate, +extend [`EnzymeCore.isvectortype`](@ref) directly instead.) """ function make_zero end """ - make_zero!(val::T, seen::IdSet{Any}=IdSet())::Nothing + make_zero!(val::T, [seen::IdDict], ::Val{runtime_inactive}=Val(false))::Nothing + +Recursively set a variable's differentiable values to zero. Only applicable for types `T` +that are mutable or hold all differentiable values in mutable storage (e.g., +`Tuple{Vector{Float64}}` qualifies but `Tuple{Float64}` does not). The recursion skips over +parts of `val` that are guaranteed to be inactive. -Recursively set a variables differentiable fields to zero. Only applicable for mutable types `T`. +The argument `runtime_inactive` specifies whether each constituent type is checked for being +guaranteed inactive at runtime for every call to `make_zero!`, or if this can be checked once +at compile-time and reused across multiple calls to `make_zero!` and related functions (the +default). Runtime checks are necessary to pick up recently added methods to +`EnzymeRules.inactive_type`, but may incur a significant performance penalty and is usually +not needed unless `EnzymeRules.inactive_type` is extended interactively for types that have +previously been passed to `make_zero!` or related functions. + +Extending this method for custom types is rarely needed. If you implement a new mutable +type, such as a GPU array type, for which `make_zero!` should directly invoke +`fill!(x, false)` rather than iterate/broadcast when the eltype is scalar, it is sufficient +to implement `Base.zero`, `Base.fill!`, and make sure your type subtypes `DenseArray`. (If +subtyping `DenseArray` is not appropriate, extend [`EnzymeCore.isvectortype`](@ref) directly +instead.) """ function make_zero! end """ - make_zero(prev::T) + isvectortype(::Type{T})::Bool -Helper function to recursively make zero. -""" -@inline function make_zero(prev::T, ::Val{copy_if_inactive}=Val(false)) where {T, copy_if_inactive} - make_zero(Core.Typeof(prev), IdDict(), prev, Val(copy_if_inactive)) +Trait defining types whose values should be considered leaf nodes when [`make_zero`](@ref) +and [`make_zero!`](@ref) recurse through an object. + +By default, `isvectortype(T) == true` for `T` such that `isscalartype(T) == true` or +`T <: DenseArray{U}` where `U` is a bitstype and `isscalartype(U) == true`. + +A new vector type, such as a GPU array type, should normally subtype `DenseArray` and +inherit `isvectortype` that way. However if this is not appropariate, `isvectortype` may be +extended directly as follows: + +```julia +@inline function EnzymeCore.isvectortype(::Type{T}) where {T<:NewArray} + U = eltype(T) + return isbitstype(U) && EnzymeCore.isscalartype(U) end +``` + +Such a type should implement `Base.zero` and, if mutable, `Base.fill!`. + +Extending `isvectortype` is mostly relevant for the lowest-level of abstraction of memory at +which vector space operations like addition and scalar multiplication are supported, the +prototypical case being `Array`. Regular Julia structs with vector space-like semantics +should normally not extend `isvectorspace`; `make_zero(!)` will recurse into them and act +directly on their backing arrays, just like how Enzyme treats them when differentiating. For +example, structured matrix wrappers and sparse array types that are backed by `Array` should +not extend `isvectortype`. + +See also [`isscalartype`](@ref). +""" +function isvectortype end + +""" + isscalartype(::Type{T})::Bool + +Trait defining a subset of [`isvectortype`](@ref) types that should not be considered +composite, such that even if the type is mutable, [`make_zero!`](@ref) will not try to zero +values of the type in-place. For example, `BigFloat` is a mutable type but does not support +in-place mutation through any Julia API; `isscalartype(BigFloat) == true` ensures that +`make_zero!` will not try to mutate `BigFloat` values.[^BigFloat] + +By default, `isscalartype(T) == true` and `isscalartype(Complex{T}) == true` for concrete +`T <: AbstractFloat`. + +A hypothetical new real number type with Enzyme support should usually subtype +`AbstractFloat` and inherit the `isscalartype` trait that way. If this is not appropriate, +the function can be extended as follows: + +```julia +@inline EnzymeCore.isscalartype(::Type{NewReal}) = true +@inline EnzymeCore.isscalartype(::Type{Complex{NewReal}}) = true +``` + +In either case, the type should implement `Base.zero`. + +See also [`isvectortype`](@ref). + +[^BigFloat]: Enzyme does not support differentiating `BigFloat` as of this writing; it is +mentioned here only to demonstrate that it would be inappropriate to use traits like +`ismutable` or `isbitstype` to choose between in-place and out-of-place zeroing, +showing the need for a dedicated `isscalartype` trait. +""" +function isscalartype end function tape_type end diff --git a/src/Enzyme.jl b/src/Enzyme.jl index e24ff41cdb..17a205a8b1 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -463,12 +463,8 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) # compute the correct complex derivative in reverse mode by propagating the conjugate return values # then subtracting twice the imaginary component to get the correct result - for (k, v) in seen - Compiler.recursive_accumulate(k, v, refn_seed) - end - for (k, v) in seen2 - Compiler.recursive_accumulate(k, v, imfn_seed) - end + Compiler.accumulate_seen!(refn_seed, seen) + Compiler.accumulate_seen!(imfn_seed, seen2) fused = fuse_complex_results(results, args...) diff --git a/src/analyses/activity.jl b/src/analyses/activity.jl index 61d2f35ab7..2f2b8e6ec8 100644 --- a/src/analyses/activity.jl +++ b/src/analyses/activity.jl @@ -427,6 +427,11 @@ Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_n return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState end +Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_nonactive_nongen(::Type{T}, world)::Bool where {T} + rt = Enzyme.Compiler.active_reg_inner(T, (), world) + return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState +end + """ Enzyme.guess_activity(::Type{T}, mode::Enzyme.Mode) diff --git a/src/compiler.jl b/src/compiler.jl index b61ec5854f..6151ee140d 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -315,7 +315,7 @@ const JuliaGlobalNameMap = Dict{String,Any}( include("absint.jl") include("llvm/transforms.jl") include("llvm/passes.jl") -include("typeutils/make_zero.jl") +include("typeutils/recursive_maps.jl") function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, @nospecialize(f), @nospecialize(tt::Type), world::UInt) funcspec = my_methodinstance(mode == API.DEM_ForwardMode ? Forward : Reverse, typeof(f), tt, world) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 04aca1a66a..1fdd874378 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -253,47 +253,6 @@ function EnzymeRules.augmented_primal( return EnzymeRules.AugmentedReturn(primal, shadow, shadow) end - -@inline function accumulate_into( - into::RT, - seen::IdDict, - from::RT, -)::Tuple{RT,RT} where {RT<:Array} - if Enzyme.Compiler.guaranteed_const(RT) - return (into, from) - end - if !haskey(seen, into) - seen[into] = (into, from) - for i in eachindex(from) - tup = accumulate_into(into[i], seen, from[i]) - @inbounds into[i] = tup[1] - @inbounds from[i] = tup[2] - end - end - return seen[into] -end - -@inline function accumulate_into( - into::RT, - seen::IdDict, - from::RT, -)::Tuple{RT,RT} where {RT<:AbstractFloat} - if !haskey(seen, into) - seen[into] = (into + from, RT(0)) - end - return seen[into] -end - -@inline function accumulate_into(into::RT, seen::IdDict, from::RT)::Tuple{RT,RT} where {RT} - if Enzyme.Compiler.guaranteed_const(RT) - return (into, from) - end - if !haskey(seen, into) - throw(AssertionError("Unknown type to accumulate into: $RT")) - end - return seen[into] -end - function EnzymeRules.reverse( config::EnzymeRules.RevConfig, func::Const{typeof(Base.deepcopy)}, @@ -302,15 +261,8 @@ function EnzymeRules.reverse( x::Annotation{Ty}, ) where {RT,Ty} if EnzymeRules.needs_shadow(config) - if EnzymeRules.width(config) == 1 - accumulate_into(x.dval, IdDict(), shadow) - else - for i = 1:EnzymeRules.width(config) - accumulate_into(x.dval[i], IdDict(), shadow[i]) - end - end + Compiler.accumulate_into!(x.dval, shadow) end - return (nothing,) end diff --git a/src/typeutils/make_zero.jl b/src/typeutils/make_zero.jl deleted file mode 100644 index 5c7b49a749..0000000000 --- a/src/typeutils/make_zero.jl +++ /dev/null @@ -1,587 +0,0 @@ -@inline function EnzymeCore.make_zero(x::FT)::FT where {FT<:AbstractFloat} - return Base.zero(x) -end -@inline function EnzymeCore.make_zero(x::Complex{FT})::Complex{FT} where {FT<:AbstractFloat} - return Base.zero(x) -end -@inline function EnzymeCore.make_zero( - x::Array{FT,N}, -)::Array{FT,N} where {FT<:AbstractFloat,N} - return Base.zero(x) -end -@inline function EnzymeCore.make_zero( - x::Array{Complex{FT},N}, -)::Array{Complex{FT},N} where {FT<:AbstractFloat,N} - return Base.zero(x) -end - - -@static if VERSION < v"1.11-" -else -@inline function EnzymeCore.make_zero( - x::GenericMemory{kind, FT}, -)::GenericMemory{kind, FT} where {FT<:AbstractFloat,kind} - return Base.zero(x) -end -@inline function EnzymeCore.make_zero( - x::GenericMemory{kind, Complex{FT}}, -)::GenericMemory{kind, Complex{FT}} where {FT<:AbstractFloat,kind} - return Base.zero(x) -end -end - - -@inline function EnzymeCore.make_zero( - ::Type{Array{FT,N}}, - seen::IdDict, - prev::Array{FT,N}, - ::Val{copy_if_inactive} = Val(false), -)::Array{FT,N} where {copy_if_inactive,FT<:AbstractFloat,N} - if haskey(seen, prev) - return seen[prev] - end - newa = Base.zero(prev) - seen[prev] = newa - return newa -end -@inline function EnzymeCore.make_zero( - ::Type{Array{Complex{FT},N}}, - seen::IdDict, - prev::Array{Complex{FT},N}, - ::Val{copy_if_inactive} = Val(false), -)::Array{Complex{FT},N} where {copy_if_inactive,FT<:AbstractFloat,N} - if haskey(seen, prev) - return seen[prev] - end - newa = Base.zero(prev) - seen[prev] = newa - return newa -end - -@static if VERSION < v"1.11-" -else -@inline function EnzymeCore.make_zero( - ::Type{GenericMemory{kind, FT}}, - seen::IdDict, - prev::GenericMemory{kind, FT}, - ::Val{copy_if_inactive} = Val(false), -)::GenericMemory{kind, FT} where {copy_if_inactive,FT<:AbstractFloat,kind} - if haskey(seen, prev) - return seen[prev] - end - newa = Base.zero(prev) - seen[prev] = newa - return newa -end -@inline function EnzymeCore.make_zero( - ::Type{GenericMemory{kind, Complex{FT}}}, - seen::IdDict, - prev::GenericMemory{kind, Complex{FT}}, - ::Val{copy_if_inactive} = Val(false), -)::GenericMemory{kind, Complex{FT}} where {copy_if_inactive,FT<:AbstractFloat,kind} - if haskey(seen, prev) - return seen[prev] - end - newa = Base.zero(prev) - seen[prev] = newa - return newa -end -end - -@inline function EnzymeCore.make_zero( - ::Type{RT}, - seen::IdDict, - prev::RT, - ::Val{copy_if_inactive} = Val(false), -)::RT where {copy_if_inactive,RT<:AbstractFloat} - return RT(0) -end - -@inline function EnzymeCore.make_zero( - ::Type{Complex{RT}}, - seen::IdDict, - prev::Complex{RT}, - ::Val{copy_if_inactive} = Val(false), -)::Complex{RT} where {copy_if_inactive,RT<:AbstractFloat} - return Complex{RT}(0) -end - -@inline function EnzymeCore.make_zero( - ::Type{RT}, - seen::IdDict, - prev::RT, - ::Val{copy_if_inactive} = Val(false), -)::RT where {copy_if_inactive,RT<:Array} - if haskey(seen, prev) - return seen[prev] - end - if guaranteed_const_nongen(RT, nothing) - return copy_if_inactive ? Base.deepcopy_internal(prev, seen) : prev - end - newa = RT(undef, size(prev)) - seen[prev] = newa - for I in eachindex(prev) - if isassigned(prev, I) - pv = prev[I] - innerty = Core.Typeof(pv) - @inbounds newa[I] = - EnzymeCore.make_zero(innerty, seen, pv, Val(copy_if_inactive)) - end - end - return newa -end - -@static if VERSION < v"1.11-" -else -@inline function EnzymeCore.make_zero( - ::Type{RT}, - seen::IdDict, - prev::RT, - ::Val{copy_if_inactive} = Val(false), -)::RT where {copy_if_inactive,RT<:GenericMemory} - if haskey(seen, prev) - return seen[prev] - end - if guaranteed_const_nongen(RT, nothing) - return copy_if_inactive ? Base.deepcopy_internal(prev, seen) : prev - end - newa = RT(undef, size(prev)) - seen[prev] = newa - for I in eachindex(prev) - if isassigned(prev, I) - pv = prev[I] - innerty = Core.Typeof(pv) - @inbounds newa[I] = - EnzymeCore.make_zero(innerty, seen, pv, Val(copy_if_inactive)) - end - end - return newa -end -end - -@inline function EnzymeCore.make_zero( - ::Type{RT}, - seen::IdDict, - prev::RT, - ::Val{copy_if_inactive} = Val(false), -)::RT where {copy_if_inactive,RT<:Tuple} - return ntuple(length(prev)) do i - Base.@_inline_meta - EnzymeCore.make_zero(RT.parameters[i], seen, prev[i], Val(copy_if_inactive)) - end -end - -@inline function EnzymeCore.make_zero( - ::Type{NamedTuple{A,RT}}, - seen::IdDict, - prev::NamedTuple{A,RT}, - ::Val{copy_if_inactive} = Val(false), -)::NamedTuple{A,RT} where {copy_if_inactive,A,RT} - prevtup = RT(prev) - TT = Core.Typeof(prevtup) # RT can be abstract - return NamedTuple{A,RT}(EnzymeCore.make_zero(TT, seen, prevtup, Val(copy_if_inactive))) -end - -@inline function EnzymeCore.make_zero( - ::Type{Core.Box}, - seen::IdDict, - prev::Core.Box, - ::Val{copy_if_inactive} = Val(false), -) where {copy_if_inactive} - if haskey(seen, prev) - return seen[prev] - end - prev2 = prev.contents - res = Core.Box() - seen[prev] = res - res.contents = EnzymeCore.make_zero(Core.Typeof(prev2), seen, prev2, Val(copy_if_inactive)) - return res -end - -@inline function EnzymeCore.make_zero( - ::Type{RT}, - seen::IdDict, - prev::RT, - ::Val{copy_if_inactive} = Val(false), -)::RT where {copy_if_inactive,RT} - if guaranteed_const_nongen(RT, nothing) - return copy_if_inactive ? Base.deepcopy_internal(prev, seen) : prev - end - if haskey(seen, prev) - return seen[prev] - end - @assert !Base.isabstracttype(RT) - @assert Base.isconcretetype(RT) - nf = fieldcount(RT) - if ismutable(prev) - y = ccall(:jl_new_struct_uninit, Any, (Any,), RT)::RT - seen[prev] = y - for i = 1:nf - if isdefined(prev, i) - xi = getfield(prev, i) - T = Core.Typeof(xi) - xi = EnzymeCore.make_zero(T, seen, xi, Val(copy_if_inactive)) - if Base.isconst(RT, i) - ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), y, i-1, xi) - else - setfield!(y, i, xi) - end - end - end - return y - end - if nf == 0 - return prev - end - flds = Vector{Any}(undef, nf) - for i = 1:nf - if isdefined(prev, i) - xi = getfield(prev, i) - xi = EnzymeCore.make_zero(Core.Typeof(xi), seen, xi, Val(copy_if_inactive)) - flds[i] = xi - else - nf = i - 1 # rest of tail must be undefined values - break - end - end - y = ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), RT, flds, nf) - seen[prev] = y - return y -end - -function make_zero_immutable!(prev::T, seen::S)::T where {T<:AbstractFloat,S} - return zero(T) -end - -function make_zero_immutable!( - prev::Complex{T}, - seen::S, -)::Complex{T} where {T<:AbstractFloat,S} - return zero(Complex{T}) -end - -function make_zero_immutable!(prev::T, seen::S)::T where {T<:Tuple,S} - if guaranteed_const_nongen(T, nothing) - return prev # unreachable from make_zero! - end - ntuple(Val(length(T.parameters))) do i - Base.@_inline_meta - p = prev[i] - SBT = Core.Typeof(p) - if guaranteed_const_nongen(SBT, nothing) - p # covered by several tests even if not shown in coverage - elseif !ismutabletype(SBT) - make_zero_immutable!(p, seen) - else - EnzymeCore.make_zero!(p, seen) - p - end - end -end - -function make_zero_immutable!(prev::NamedTuple{a,b}, seen::S)::NamedTuple{a,b} where {a,b,S} - if guaranteed_const_nongen(NamedTuple{a,b}, nothing) - return prev # unreachable from make_zero! - end - NamedTuple{a,b}(ntuple(Val(length(b.parameters))) do i - Base.@_inline_meta - p = prev[a[i]] - SBT = Core.Typeof(p) - if guaranteed_const_nongen(SBT, nothing) - p # covered by several tests even if not shown in coverage - elseif !ismutabletype(SBT) - make_zero_immutable!(p, seen) - else - EnzymeCore.make_zero!(p, seen) - p - end - end) -end - - -function make_zero_immutable!(prev::T, seen::S)::T where {T,S} - if guaranteed_const_nongen(T, nothing) - return prev # unreachable from make_zero! - end - @assert !ismutabletype(T) - @assert !Base.isabstracttype(T) - @assert Base.isconcretetype(T) - nf = fieldcount(T) - flds = Vector{Any}(undef, nf) - for i = 1:nf - if isdefined(prev, i) - xi = getfield(prev, i) - ST = Core.Typeof(xi) - flds[i] = if guaranteed_const_nongen(ST, nothing) - xi - elseif !ismutabletype(ST) - make_zero_immutable!(xi, seen) - else - EnzymeCore.make_zero!(xi, seen) - xi - end - else - nf = i - 1 # rest of tail must be undefined values - break - end - end - return ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), T, flds, nf)::T -end - -@inline function EnzymeCore.make_zero!( - prev::Base.RefValue{T}, - seen::ST, -)::Nothing where {T<:AbstractFloat,ST} - if !isnothing(seen) - if prev in seen - return nothing - end - push!(seen, prev) - end - prev[] = zero(T) - return nothing -end - -@inline function EnzymeCore.make_zero!( - prev::Base.RefValue{Complex{T}}, - seen::ST, -)::Nothing where {T<:AbstractFloat,ST} - if !isnothing(seen) - if prev in seen - return nothing - end - push!(seen, prev) - end - prev[] = zero(Complex{T}) - return nothing -end - -@inline function EnzymeCore.make_zero!( - prev::Array{T,N}, - seen::ST, -)::Nothing where {T<:AbstractFloat,N,ST} - if !isnothing(seen) - if prev in seen - return nothing - end - push!(seen, prev) - end - fill!(prev, zero(T)) - return nothing -end - -@inline function EnzymeCore.make_zero!( - prev::Array{Complex{T},N}, - seen::ST, -)::Nothing where {T<:AbstractFloat,N,ST} - if !isnothing(seen) - if prev in seen - return nothing - end - push!(seen, prev) - end - fill!(prev, zero(Complex{T})) - return nothing -end - -@static if VERSION < v"1.11-" -else -@inline function EnzymeCore.make_zero!( - prev::GenericMemory{kind, T}, - seen::ST, -)::Nothing where {T<:AbstractFloat,kind,ST} - if !isnothing(seen) - if prev in seen - return nothing - end - push!(seen, prev) - end - fill!(prev, zero(T)) - return nothing -end - -@inline function EnzymeCore.make_zero!( - prev::GenericMemory{kind, Complex{T}}, - seen::ST, -)::Nothing where {T<:AbstractFloat,kind,ST} - if !isnothing(seen) - if prev in seen - return nothing - end - push!(seen, prev) - end - fill!(prev, zero(Complex{T})) - return nothing -end -end - -@inline function EnzymeCore.make_zero!( - prev::Base.RefValue{T}, -)::Nothing where {T<:AbstractFloat} - EnzymeCore.make_zero!(prev, nothing) - return nothing -end - -@inline function EnzymeCore.make_zero!( - prev::Base.RefValue{Complex{T}}, -)::Nothing where {T<:AbstractFloat} - EnzymeCore.make_zero!(prev, nothing) - return nothing -end - -@inline function EnzymeCore.make_zero!(prev::Array{T,N})::Nothing where {T<:AbstractFloat,N} - EnzymeCore.make_zero!(prev, nothing) - return nothing -end - -@inline function EnzymeCore.make_zero!( - prev::Array{Complex{T},N}, -)::Nothing where {T<:AbstractFloat,N} - EnzymeCore.make_zero!(prev, nothing) - return nothing -end - -@inline function EnzymeCore.make_zero!(prev::Array{T,N}, seen::ST)::Nothing where {T,N,ST} - if guaranteed_const_nongen(T, nothing) - return nothing - end - if prev in seen - return nothing - end - push!(seen, prev) - for I in eachindex(prev) - if isassigned(prev, I) - pv = prev[I] - SBT = Core.Typeof(pv) - if guaranteed_const_nongen(SBT, nothing) - continue - elseif !ismutabletype(SBT) - @inbounds prev[I] = make_zero_immutable!(pv, seen) - else - EnzymeCore.make_zero!(pv, seen) - end - end - end - return nothing -end - -@static if VERSION < v"1.11-" -else -@inline function EnzymeCore.make_zero!(prev::GenericMemory{kind, T})::Nothing where {T<:AbstractFloat,kind} - EnzymeCore.make_zero!(prev, nothing) - return nothing -end - -@inline function EnzymeCore.make_zero!( - prev::GenericMemory{kind, Complex{T}}, -)::Nothing where {T<:AbstractFloat, kind} - EnzymeCore.make_zero!(prev, nothing) - return nothing -end - -@inline function EnzymeCore.make_zero!(prev::GenericMemory{kind, T}, seen::ST)::Nothing where {T,kind,ST} - if guaranteed_const_nongen(T, nothing) - return nothing - end - if prev in seen - return nothing - end - push!(seen, prev) - for I in eachindex(prev) - if isassigned(prev, I) - pv = prev[I] - SBT = Core.Typeof(pv) - if guaranteed_const_nongen(SBT, nothing) - continue - elseif !ismutabletype(SBT) - @inbounds prev[I] = make_zero_immutable!(pv, seen) - else - EnzymeCore.make_zero!(pv, seen) - end - end - end - return nothing -end -end - - -@inline function EnzymeCore.make_zero!( - prev::Base.RefValue{T}, - seen::ST, -)::Nothing where {T,ST} - if guaranteed_const_nongen(T, nothing) - return nothing - end - if prev in seen - return nothing - end - push!(seen, prev) - pv = prev[] - SBT = Core.Typeof(pv) - if guaranteed_const_nongen(SBT, nothing) - return nothing - elseif !ismutabletype(SBT) - prev[] = make_zero_immutable!(pv, seen) - else - EnzymeCore.make_zero!(pv, seen) - end - return nothing -end - -@inline function EnzymeCore.make_zero!(prev::Core.Box, seen::ST)::Nothing where {ST} - if prev in seen - return nothing - end - push!(seen, prev) - pv = prev.contents - SBT = Core.Typeof(pv) - if guaranteed_const_nongen(SBT, nothing) - return nothing - elseif !ismutabletype(SBT) - prev.contents = make_zero_immutable!(pv, seen) - else - EnzymeCore.make_zero!(pv, seen) - end - return nothing -end - -@inline function EnzymeCore.make_zero!(prev::T, seen::S)::Nothing where {T,S} - if guaranteed_const_nongen(T, nothing) - return nothing - end - if prev in seen - return nothing - end - @assert !Base.isabstracttype(T) - @assert Base.isconcretetype(T) - nf = fieldcount(T) - if nf == 0 - return nothing - end - push!(seen, prev) - for i = 1:nf - if isdefined(prev, i) - xi = getfield(prev, i) - SBT = Core.Typeof(xi) - activitystate = active_reg_inner(SBT, (), nothing) - if activitystate == AnyState # guaranteed_const - continue - elseif ismutabletype(T) && !ismutabletype(SBT) - yi = make_zero_immutable!(xi, seen) - if Base.isconst(T, i) - ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), prev, i-1, yi) - else - setfield!(prev, i, yi) - end - elseif activitystate == DupState - EnzymeCore.make_zero!(xi, seen) - else - msg = "cannot set $xi to zero in-place, as it contains differentiable values in immutable positions" - throw(ArgumentError(msg)) - end - end - end - return nothing -end - -@inline EnzymeCore.make_zero!(prev) = EnzymeCore.make_zero!(prev, Base.IdSet()) diff --git a/src/typeutils/recursive_add.jl b/src/typeutils/recursive_add.jl index 039f7d3d0c..b6e12abb39 100644 --- a/src/typeutils/recursive_add.jl +++ b/src/typeutils/recursive_add.jl @@ -1,86 +1,130 @@ -# Recursively return x + f(y), where y is active, otherwise x - -@inline function recursive_add( - x::T, - y::T, - f::F = identity, - forcelhs::F2 = guaranteed_const, -) where {T,F,F2} - if forcelhs(T) - return x +using .RecursiveMaps: RecursiveMaps, recursive_map, recursive_map! + +""" + recursive_add(x::T, y::T, f=identity, forcelhs=guaranteed_const) + +!!! warning + Internal function, documented for developer convenience but not covered by semver API + stability guarantees + +Recursively construct `z::T` such that `zi = xi + f(yi)` where `zi`, `xi`, and `yi` are +corresponding values from `z`, `x`, and `y`. In other words, this is a recursive +generalization of `x .+ f.(y)`. + +The function `f` must return values of the same type as its argument. + +The optional argument `forcelhs` takes a callable such that if `forcelhs(S) == true`, values +`zi::S` will be set to `zi = xi`. The default returns true for non-differentiable types, +such that `zi = xi + f(yi)` applies to differentiable values, while `zi = xi` applies to +non-differentiable values. +""" +function recursive_add(x::T, y::T, f::F=identity, forcelhs::L=guaranteed_const) where {T,F,L} + function addf(xi::S, yi::S) where {S} + @assert EnzymeCore.isvectortype(S) + return (xi + f(yi),)::Tuple{S} end - splatnew(T, ntuple(Val(fieldcount(T))) do i - Base.@_inline_meta - prev = getfield(x, i) - next = getfield(y, i) - recursive_add(prev, next, f, forcelhs) - end) + return only(recursive_map(addf, Val(1), (x, y), Val(false), forcelhs))::T end -@inline function recursive_add( - x::T, - y::T, - f::F = identity, - forcelhs::F2 = guaranteed_const, -) where {T<:AbstractFloat,F,F2} - if forcelhs(T) - return x - end - return x + f(y) +""" + accumulate_seen!(f, seen::IdDict, ::Val{runtime_inactive}=Val(false)) + accumulate_seen!(f, seen::IdDict, isinactivetype::RecursiveMaps.IsInactive) + +!!! warning + Internal function, documented for developer convenience but not covered by semver API + stability guarantees + +Recursively accumulate from values into keys, generalizing key .+= f.(value), for each +key-value pair in `seen::IdDict` where each key must be a mutable object or non-isbits +vector type instance mappping to another object of the same type and structure. Typically +`seen` is populated by `make_zero` (or some other single-argument invocation of +`recursive_map`), mapping components of its argument to the corresponding component of the +returned value. + +The recursion stops at instances of types that are themselves cached by `make_zero` +(`recursive_map`), as these objects should have their own entries in `seen`. + +Inactive objects that would be shared/copied rather than zeroed by `make_zero` are skipped. +If the optional `::Val{runtime_inactive}` argument was passed to `make_zero`, the same value +should be passed to `accumulate_seen` for consistency. If needed, a custom +`RecursiveMaps.IsInactive` instance can be provided instead. +""" +function accumulate_seen!( + f::F, seen::IdDict, ::Val{runtime_inactive}=Val(false) +) where {F,runtime_inactive} + accumulate_seen!(f, seen, RecursiveMaps.IsInactive{runtime_inactive}()) + return nothing end -@inline function recursive_add( - x::T, - y::T, - f::F = identity, - forcelhs::F2 = guaranteed_const, -) where {T<:Complex,F,F2} - if forcelhs(T) - return x +function accumulate_seen!( + f::F, seen::IdDict, isinactivetype::RecursiveMaps.IsInactive +) where {F} + for (k, v) in seen + _accumulate_seen_item!(f, k, v, isinactivetype) end - return x + f(y) + return nothing end -@inline mutable_register(::Type{T}) where {T<:Integer} = true -@inline mutable_register(::Type{T}) where {T<:AbstractFloat} = false -@inline mutable_register(::Type{Complex{T}}) where {T<:AbstractFloat} = false -@inline mutable_register(::Type{T}) where {T<:Tuple} = false -@inline mutable_register(::Type{T}) where {T<:NamedTuple} = false -@inline mutable_register(::Type{Core.Box}) = true -@inline mutable_register(::Type{T}) where {T<:Array} = true -@inline mutable_register(::Type{T}) where {T} = ismutabletype(T) - -# Recursively In-place accumulate(aka +=). E.g. generalization of x .+= f(y) -@inline function recursive_accumulate(x::Array{T}, y::Array{T}, f::F = identity) where {T,F} - if !mutable_register(T) - for I in eachindex(x) - prev = x[I] - @inbounds x[I] = recursive_add(x[I], (@inbounds y[I]), f, mutable_register) - end +function _accumulate_seen_item!( + f::F, k::T, v::T, isinactivetype::RecursiveMaps.IsInactive +) where {F,T} + function addf!!(ki::S, vi::S) where {S} + @assert EnzymeCore.isvectortype(S) + return (ki .+ f.(vi),)::Tuple{S} + end + function addf!!(ki::S, _ki::S, vi::S) where {S} + @assert !EnzymeCore.isscalartype(S) + @assert EnzymeCore.isvectortype(S) + @assert ki === _ki + ki .+= f.(vi) + return (ki,)::Tuple{S} + end + RecursiveMaps.check_nonactive(T, isinactivetype) + if !isinactivetype(T) + is_inactive_or_seen_type = RecursiveMaps.IsInactive( + isinactivetype, RecursiveMaps.iscachedtype + ) + newks = RecursiveMaps.recursive_map_inner( + nothing, addf!!, (k,), (k, v), Val(false), is_inactive_or_seen_type + ) + @assert only(newks) === k end + return nothing end +""" + accumulate_into!(into::T, from::T) -# Recursively In-place accumulate(aka +=). E.g. generalization of x .+= f(y) -@inline function recursive_accumulate(x::Core.Box, y::Core.Box, f::F = identity) where {F} - recursive_accumulate(x.contents, y.contents, seen, f) -end +!!! warning + Internal function, documented for developer convenience but not covered by semver API + stability guarantees + +Recursively accumulate from `from` into `into` and zero `from`, such that `into_i += from_i` +and `from_i = 0`, where `into_i` and `from_i` are corresponding values within `into` and +`from`. In other words, this is a recursive generalization of -@inline function recursive_accumulate(x::T, y::T, f::F = identity) where {T,F} - @assert !Base.isabstracttype(T) - @assert Base.isconcretetype(T) - nf = fieldcount(T) - - for i = 1:nf - if isdefined(x, i) - xi = getfield(x, i) - ST = Core.Typeof(xi) - if !mutable_register(ST) - @assert ismutable(x) - yi = getfield(y, i) - nexti = recursive_add(xi, yi, f, mutable_register) - setfield!(x, i, nexti) - end - end +```julia +into .+= from +from .= 0 +``` + +The accumulation and zeroing is only applied to differentiable values; non-differentiable +values within both `into` and `from` are left untouched. +""" +function accumulate_into!(into::T, from::T) where {T} + # may not show in coverage but both base cases are covered via deepcopy custom rule tests + function accumulate_into!!(into_i::S, from_i::S) where {S} + @assert EnzymeCore.isvectortype(S) + return (into_i + from_i, convert(S, zero(from_i)))::Tuple{S,S} + end + function accumulate_into!!(into_i::S, from_i::S, _into_i::S, _from_i::S) where {S} + @assert !EnzymeCore.isscalartype(S) + @assert EnzymeCore.isvectortype(S) + @assert (into_i === _into_i) && (from_i === _from_i) + into_i .+= from_i + fill!(from_i, false) + return (into_i, from_i)::Tuple{S,S} end + recursive_map!(accumulate_into!!, (into, from), (into, from)) + return nothing end diff --git a/src/typeutils/recursive_maps.jl b/src/typeutils/recursive_maps.jl new file mode 100644 index 0000000000..245ce22632 --- /dev/null +++ b/src/typeutils/recursive_maps.jl @@ -0,0 +1,769 @@ +module RecursiveMaps + +using EnzymeCore: EnzymeCore, isvectortype, isscalartype +using ..Compiler: Compiler, guaranteed_const, guaranteed_const_nongen, guaranteed_nonactive, + guaranteed_nonactive_nongen + +### IsInactive: helper for creating consistent inactive/nonactive type checkers +""" + isinactivetype = IsInactive{runtime::Bool}(extra=(T -> false)) + isinactivetype = IsInactive(isinactivetype::IsInactive, extra) + +!!! warning + Internal type, documented for developer convenience but not covered by semver API + stability guarantees + +Create a callable `isinactivetype` such that `isinactivetype(T) == true` if the type `T` is +non-differentiable, that is, if differentiable values can never be reached from any instance +of the type (that is, the activity state of `T` is `AnyState`). + +The callable takes an optional argument `Val(nonactive::Bool)`, such that the full signature +is + +```julia +isinactivetype(::Type{T}, ::Val{nonactive}=Val(false))::Bool +``` + +Setting `nonactive == true` selects for _nonactive_ types, which is a superset of inactive +types that also includes types `T` where every differentiable value can be mutated without +creating a new instance of `T` (that is, the activity state of `T` is either `AnyState` or +`DupState`). + +The optional argument `extra` takes a function defining additional types that should be +treated as inactive regardless of their nominal activity state; that is, + +```julia +IsInactive{runtime}(extra)(T, args...) == IsInactive{runtime}()(T, args...) || extra(T) +``` + +The constructor `IsInactive(isinactivetype::IsInactive{runtime}, extra)` can be used to +extend an existing instance `isinactivetype::IsInactive` with an additional `extra` +function, and is more or less equivalent to +`IsInactive{runtime}(T -> isinactivetype.extra(T) || extra(T))`. + +The type parameter `runtime` specifies whether the activity state of a type is queried at +runtime every time the callable is invoked (`true`), or if compile-time queries from earlier +calls can be reused (`false`). Runtime querying is necessary to pick up recently added +methods to `EnzymeRules.inactive_type`, but may incur a significant performance penalty and +is usually not needed unless `EnzymeRules.inactive_type` is extended interactively for types +that have previously been passed to an instance of `IsInactive{false}`. +""" +struct IsInactive{runtime,F} + extra::F + function IsInactive{runtime}( + extra::F=(@nospecialize(T) -> (@inline; false)) + ) where {runtime,F} + return new{runtime::Bool,F}(extra) + end +end + +function IsInactive(isinactivetype::IsInactive{runtime}, extra::F) where {runtime,F} + combinedextra(::Type{T}) where {T} = (isinactivetype.extra(T) || extra(T)) + return IsInactive{runtime}(combinedextra) +end + +@inline function (f::IsInactive{runtime,F})( + ::Type{T}, ::Val{nonactive}=Val(false) +) where {runtime,F,T,nonactive} + if runtime + # evaluate f.extra first, as guaranteed_*_nongen may incur runtime dispatch + if nonactive + return f.extra(T) || guaranteed_nonactive_nongen(T, nothing) + else + return f.extra(T) || guaranteed_const_nongen(T, nothing) + end + else + # evaluate guaranteed_* first, as these are always known at compile time + if nonactive + return guaranteed_nonactive(T) || f.extra(T) + else + return guaranteed_const(T) || f.extra(T) + end + end +end + +### traits defining active leaf types for recursive_map +@inline isdensearraytype(::Type{<:DenseArray}) = true +@inline isdensearraytype(::Type) = false + +@inline EnzymeCore.isvectortype(::Type{T}) where {T} = isscalartype(T) +@inline function EnzymeCore.isvectortype(::Type{<:DenseArray{U}}) where {U} + return isbitstype(U) && isscalartype(U) +end + +@inline EnzymeCore.isscalartype(::Type) = false +@inline EnzymeCore.isscalartype(::Type{T}) where {T<:AbstractFloat} = isconcretetype(T) +@inline function EnzymeCore.isscalartype(::Type{Complex{T}}) where {T<:AbstractFloat} + return isconcretetype(T) +end + +### recursive_map: walk arbitrary objects and map a function over scalar and vector leaves +""" + ys = recursive_map( + [seen::Union{Nothing,IdDict},] + f, + ::Val{Nout} + xs::NTuple{Nin,T}, + ::Val{copy_if_inactive}=Val(false), + isinactivetype=IsInactive{false}(), + )::T + newys = recursive_map( + [seen::Union{Nothing,IdDict},] + f, + ys::NTuple{Nout,T}, + xs::NTuple{Nin,T}, + ::Val{copy_if_inactive}=Val(false), + isinactivetype=IsInactive{false}(), + )::T + +!!! warning + Internal function, documented for developer convenience but not covered by semver API + stability guarantees + +Recurse through `Nin` objects `xs = (x1::T, x2::T, ..., xNin::T)` of the same type, mapping the +function `f` over every differentiable value encountered and building `Nout` new objects +`(y1::T, y2::T, ..., yNout::T)` from the resulting values +`(y1_i, ..., yNout_i) = f(x1_i, ..., xNin_i)`. + +The trait `EnzymeCore.isvectortype`(@ref) determines which values are considered +differentiable leaf nodes at which recursion terminates and `f` is invoked. See the +docstring for [`EnzymeCore.isvectortype`](@ref) and the related +[`EnzymeCore.isscalartype`](@ref) for more information. + +A tuple of existing objects `ys = (y1::T, ..., yNout::T)` can be passed, in which case the +`ys` are updated "partially-in-place": any parts of the `ys` that are mutable or +non-differentiable are reused in the returned object tuple `newys`, while immutable +differentiable parts are handled out-of-place as if the `ys` were not passed (this can be +seen as a recursive generalization of the BangBang.jl idiom). If `T` itself is a mutable +type, the `ys` are modified in-place and returned, such that `newys === ys`. + +The recursion and mapping operates on the structure of `T` as defined by struct fields and +plain array elements, not on the values provided through an iteration or array interface. +For example, given a structured matrix wrapper or sparse array type, this function recurses +into the struct type and the plain arrays held within, rather than operating on the array +that the type notionally represents. + +# Arguments + +* `seen::Union{IdDict,Nothing}` (optional): Dictionary for tracking object identity as + needed to construct `y` such that its internal graph of object references is identical to + that of the `xs`, including cycles (i.e., recursive substructures) and multiple paths to + the same objects. If not provided, an `IdDict` will be allocated internally if required. + + If `nothing` is provided, object identity is not tracked. In this case, objects with + multiple references are duplicated such that the `ys`s object reference graph becomes a + tree, cycles lead to infinite recursion and stack overflow, and `copy_if_inactive == true` + will likely cause errors. This is useful only in specific cases. + +* `f`: Function mapping leaf nodes within the `xs` to the corresponding leaf nodes in the + `ys`, that is, `(y1_i, ..., yNout_i) = f(x1_i::U, ..., xNin_i::U)::NTuple{Nout,U}`. + The function `f` must be applicable to the type of every leaf node, and must return a + tuple of values of the same type as its arguments. + + When an existing object tuple `ys` is passed and contains leaf nodes of a non-isbits + non-scalar type `U`, `f` should also have a partially-in-place method + `(newy1_i, ..., newyNout_i) === f(y1_i::U, ..., yNout_i::U, x1_i::U, ..., xNin_i::U)::NTuple{Nout,U}` + that modifies and reuses any mutable parts of the `yj_i`; in particular, if `U` is a + mutable type, this method should return `newyj_i === yj_i`. If a non-isbits type `U` + should always be handled using the out-of-place signature, extend + [`EnzymeCore.isscalartype`](@ref) such that `isscalartype(U) == true`. + + See [`EnzymeCore.isvectortype`](@ref) and [`EnzymeCore.isscalartype`](@ref) for more + details about leaf types and scalar types. + +* `::Val{Nout}` or `ys::NTuple{Nout,T}`: For out-of-place operation, pass `Val(Nout)` where + `Nout` is the length of the tuple returned by `f`, that is, the length of the expected + return value `ys` (this is required; `Nout` never inferred). For partially-in-place + operation, pass the existing tuple `ys::NTuple{Nout,T}` containing the values to be + modified. + +* `xs::NTuple{N,T}`: Tuple of `N` objects of the same type `T` over which `f` is mapped. + + The first object `x1 = first(xs)` is the reference for graph structure and + non-differentiable values when constructing the returned object. In particular: + * When `ys` is not passed, the returned objects take any non-differentiable parts from + `x1`. (When `ys` is passed, its non-differentiable parts are kept unchanged in the + returned object, unless they are not initialized, in which case they are taken from + `x1`.) + * The graph of object references in `x1` is the one which is reproduced in the returned + object. For each instance of multiple paths and cycles within `x1`, the same structure + must be present in the other objects `x2, ..., xN`, otherwise the corresponding values + in the `ys` would not be uniquely defined. However, `x2, ..., xN` may contain multiple + paths or cycles that are not present in `x1`; these do not affect the structure of `ys`. + * If any values within `x1` are not initialized (that is, struct fields are undefined or + array elements are unassigned), they are left uninitialized in the returned object. If + any such values are mutable and `ys` is passed, the corresponding value in `y` must not + already be initialized, since initialized values cannot be nulled. Conversely, for every + value in `x1` that is initialized, the corresponding values in `x2, ..., xN` must also + be initialized, such that the corresponding values of the `ys` can be computed (however, + values in `x2, ..., xN` can be initialized while the corresponding value in `x1` is not; + such values are ignored.) + +* `::Val{copy_if_inactive::Bool}` (optional): When a non-differentiable part of `x1` is + included in the returned object, either because an object tuple `ys` is not passed or this + part of the `ys` is not initialized, `copy_if_inactive` determines how: if + `copy_if_inactive == false`, it is shared as `yj_i = x1_i`; if `copy_if_inactive == true`, + it is deep-copied, more-or-less as `yj_i = deepcopy(x1_i)` (the difference is that when + `x1` has several non-differentiable parts, object identity is tracked across the multiple + deep-copies such that the object reference graph is reproduced also within the inactive + parts.) + +* `isinactivetype` (optional): Callable determining which types are considered inactive and + thus treated according to `copy_if_inactive`. The [`IsInactive`](@ref) type is a + convenient helper for obtaining a callable with relevant semantics, but any callable that + maps types to `true` or `false` can be used. +""" +function recursive_map end + +## type alias for unified handling of out-of-place and partially-in-place recursive_map +const YS{Nout,T} = Union{Val{Nout},NTuple{Nout,T}} +@inline hasvalues(::Val{Nout}) where {Nout} = (Nout::Int; false) +@inline hasvalues(::NTuple) = true + +## main entry point: set default arguments, allocate IdDict if needed, exit early if possible +function recursive_map( + f::F, + ys::YS{Nout,T}, + xs::NTuple{Nin,T}, + copy_if_inactive::Val=Val{false}, + isinactivetype::L=IsInactive{false}(), +) where {F,Nout,Nin,T,L} + newys = if isinactivetype(T) + recursive_map_inactive(nothing, ys, xs, copy_if_inactive) + elseif isvectortype(T) || isbitstype(T) + recursive_map_inner(nothing, f, ys, xs, copy_if_inactive, isinactivetype) + else + recursive_map_inner(IdDict(), f, ys, xs, copy_if_inactive, isinactivetype) + end + return newys::NTuple{Nout,T} +end + +## recursive methods +function recursive_map( + seen::Union{Nothing,IdDict}, + f::F, + ys::YS{Nout,T}, + xs::NTuple{Nin,T}, + copy_if_inactive::Val=Val{false}, + isinactivetype::L=IsInactive{false}(), +) where {F,Nout,Nin,T,L} + # determine whether to continue recursion, copy/share, or retrieve from cache + newys = if isinactivetype(T) + recursive_map_inactive(seen, ys, xs, copy_if_inactive) + elseif isbitstype(T) # no object identity to to track in this branch + recursive_map_inner(nothing, f, ys, xs, copy_if_inactive, isinactivetype) + elseif hascache(seen, xs) + getcached(seen, Val(Nout), xs) + else + recursive_map_inner(seen, f, ys, xs, copy_if_inactive, isinactivetype) + end + return newys::NTuple{Nout,T} +end + +@inline function recursive_map_inner( + seen, f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, copy_if_inactive, isinactivetype::L +) where {F,Nout,Nin,T,L} + # forward to appropriate handler for leaf vs. mutable vs. immutable type + @assert !isabstracttype(T) + @assert isconcretetype(T) + newys = if isvectortype(T) + recursive_map_leaf(seen, f, ys, xs) + elseif ismutabletype(T) + recursive_map_mutable(seen, f, ys, xs, copy_if_inactive, isinactivetype) + else + recursive_map_immutable(seen, f, ys, xs, copy_if_inactive, isinactivetype) + end + return newys::NTuple{Nout,T} +end + +@generated function recursive_map_mutable( + seen, f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, copy_if_inactive, isinactivetype::L +) where {F,Nout,Nin,T,L} + @assert ismutabletype(T) + iteration_i = quote + @inbounds if isinitialized(x1, i) + check_allinitialized(xtail, i) + newys_i = recursive_map_index(i, seen, f, ys, xs, copy_if_inactive, isinactivetype) + setitems!(newys, i, newys_i) + elseif hasvalues(ys) + check_allinitialized(ys, i, false) + end + end + return quote + @inline + if !hasvalues(ys) && !isdensearraytype(T) && all(isbitstype, fieldtypes(T)) + # fast path for out-of-place handling when all fields are bitstypes, which rules + # out undefined fields and circular references + newys = recursive_map_new(seen, f, ys, xs, copy_if_inactive, isinactivetype) + maybecache!(seen, newys, xs) + else + x1, xtail = first(xs), Base.tail(xs) + newys = if hasvalues(ys) + ys + else + Base.@ntuple $Nout _ -> _similar(x1) + end + maybecache!(seen, newys, xs) + if isdensearraytype(T) + if (Nout == 1) && isbitstype(eltype(T)) + recursive_map_broadcast!( + f, newys, ys, xs, copy_if_inactive, isinactivetype + ) + else + for i in eachindex(newys..., xs...) + $iteration_i + end + end + else # unrolled loop over struct fields + Base.Cartesian.@nexprs $(fieldcount(T)) i -> $iteration_i + end + end + return newys::NTuple{Nout,T} + end +end + +@generated function recursive_map_immutable( + seen, f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, copy_if_inactive, isinactivetype::L +) where {F,Nout,Nin,T,L} + @assert !ismutabletype(T) + nf = fieldcount(T) + return quote + @inline + if $nf == 0 # nothing to do (also no known way to hit this branch) + newys = recursive_map_inactive(nothing, ys, xs, Val(false)) + else + x1, xtail = first(xs), Base.tail(xs) + if isinitialized(x1, $nf) # fast path when all fields are defined + check_allinitialized(xtail, $nf) + newys = recursive_map_new(seen, f, ys, xs, copy_if_inactive, isinactivetype) + else + Base.Cartesian.@nexprs $Nout j -> (fields_j = Vector{Any}(undef, $(nf - 1))) + Base.Cartesian.@nexprs $(nf - 1) i -> begin # unrolled loop over struct fields + @inbounds if isinitialized(x1, i) + check_allinitialized(xtail, i) + newys_i = recursive_map_index( + i, seen, f, ys, xs, copy_if_inactive, isinactivetype + ) + Base.Cartesian.@nexprs $Nout j -> (fields_j[i] = newys_i[j]) + else + ndef = i - 1 # rest of tail must be undefined values + @goto done # break out of unrolled loop + end + end + ndef = $(nf - 1) # loop didn't break, only last field is undefined + @label done + newys = Base.@ntuple $Nout j -> begin + ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), T, fields_j, ndef)::T + end + end + # maybecache! _should_ be a no-op here; call it anyway for consistency + maybecache!(seen, newys, xs) + end + return newys::NTuple{Nout,T} + end +end + +@generated function recursive_map_new( + seen, f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, copy_if_inactive, isinactivetype::L +) where {F,Nout,Nin,T,L} + # direct construction of fully initialized non-cyclic structs + nf = fieldcount(T) + return quote + @inline + Base.Cartesian.@nexprs $nf i -> begin + newys_i = @inbounds recursive_map_index( + i, seen, f, ys, xs, copy_if_inactive, isinactivetype + ) + end + newys = Base.@ntuple $Nout j -> begin + $(Expr(:splatnew, :T, :(Base.@ntuple $nf i -> newys_i[j]))) + end + return newys::NTuple{Nout,T} + end +end + +@inline function recursive_map_broadcast!( + f::F, newys::NTuple{1,T}, ys::YS{1,T}, xs::NTuple{Nin,T}, copy_if_inactive, isinactivetype::L +) where {F,Nin,T,L} + # broadcast recursive_map over array-like inputs with isbits elements + @assert isdensearraytype(T) + @assert isbitstype(eltype(T)) + newy = first(newys) + if hasvalues(ys) + @assert newys === ys + broadcast!( + (newy_i, xs_i...) -> first(recursive_map_barrier!!( + nothing, f, copy_if_inactive, isinactivetype, Val(1), newy_i, xs_i... + )), + newy, + newy, + xs..., + ) + else + broadcast!( + (xs_i...,) -> first(recursive_map_barrier( + nothing, f, copy_if_inactive, isinactivetype, Val(1), xs_i... + )), + newy, + xs..., + ) + end + return nothing +end + +Base.@propagate_inbounds function recursive_map_index( + i, seen, f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, copy_if_inactive, isinactivetype::L +) where {F,Nout,Nin,T,L} + # recurse into the xs and apply recursive_map to items with index i + xs_i = getitems(xs, i) + newys_i = if hasvalues(ys) && isinitialized(first(ys), i) + check_allinitialized(Base.tail(ys), i) + ys_i = getitems(ys, i) + recursive_map_barrier!!( + seen, f, copy_if_inactive, isinactivetype, Val(Nout), ys_i..., xs_i... + ) + else + recursive_map_barrier(seen, f, copy_if_inactive, isinactivetype, Val(Nout), xs_i...) + end + return newys_i +end + +# function barriers such that abstractly typed items trigger minimal runtime dispatch +function recursive_map_barrier( + seen, f::F, copy_if_inactive::Val, isinactivetype::L, ::Val{Nout}, xs_i::Vararg{ST,Nin} +) where {F,Nout,Nin,ST,L} + return recursive_map( + seen, f, Val(Nout), xs_i, copy_if_inactive, isinactivetype + )::NTuple{Nout,ST} +end + +function recursive_map_barrier!!( # TODO: hit this when VectorSpace implemented + seen, f::F, copy_if_inactive, isinactivetype::L, ::Val{Nout}, yxs_i::Vararg{ST,M} +) where {F,Nout,M,ST,L} + ys_i, xs_i = yxs_i[1:(Nout::Int)], yxs_i[((Nout::Int)+1):end] + return recursive_map( + seen, f, ys_i, xs_i, copy_if_inactive, isinactivetype + )::NTuple{Nout,ST} +end + +# specialized methods to optimize the common cases Nout == 1 and Nout == 2 +function recursive_map_barrier!!( + seen, f::F, copy_if_inactive::Val, isinactivetype::L, ::Val{1}, yi::ST, xs_i::Vararg{ST,Nin} +) where {F,Nin,ST,L} + return recursive_map( + seen, f, (yi,), xs_i, copy_if_inactive, isinactivetype + )::NTuple{1,ST} +end + +function recursive_map_barrier!!( # may not show in coverage but is covered via accumulate_into! TODO: ensure coverage via VectorSpace once implemented + seen, + f::F, + copy_if_inactive::Val, + isinactivetype::L, + ::Val{2}, + y1_i::ST, + y2_i::ST, + xs_i::Vararg{ST,Nin} +) where {F,Nin,ST,L} + return recursive_map( + seen, f, (y1_i, y2_i), xs_i, copy_if_inactive, isinactivetype + )::NTuple{2,ST} +end + +## recursion base case handlers +@inline function recursive_map_leaf( + seen, f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T} +) where {F,Nout,Nin,T} + # apply the mapped function to leaf values + newys = if !hasvalues(ys) || isbitstype(T) || isscalartype(T) + f(xs...)::NTuple{Nout,T} + else # !isbitstype(T) + newys_ = f(ys..., xs...)::NTuple{Nout,T} + if ismutabletype(T) + @assert newys_ === ys + end + newys_ + end + maybecache!(seen, newys, xs) + return newys::NTuple{Nout,T} +end + +@inline function recursive_map_inactive( + _, ys::NTuple{Nout,T}, xs::NTuple{Nin,T}, ::Val{copy_if_inactive} +) where {Nout,Nin,T,copy_if_inactive} + return ys::NTuple{Nout,T} +end + +@generated function recursive_map_inactive( + seen, ::Val{Nout}, xs::NTuple{Nin,T}, ::Val{copy_if_inactive} +) where {Nout,Nin,T,copy_if_inactive} + return quote + @inline + y = if copy_if_inactive && !isbitstype(T) + Base.deepcopy_internal(first(xs), isnothing(seen) ? IdDict() : seen) + else + first(xs) + end + return (Base.@ntuple $Nout _ -> y)::NTuple{Nout,T} + end +end + +### recursive_map!: fully in-place wrapper around recursive_map +""" + recursive_map!( + [seen::IdDict,] + f!!, + ys::NTuple{Nout,T}, + xs::NTuple{Nin,T}, + ::Val{copy_if_inactive}=Val(false), + isinactivetype::IsInactive=IsInactive{false}(), + )::Nothing + +!!! warning + Internal function, documented for developer convenience but not covered by semver API + stability guarantees + +Recurse through `Nin` objects `xs = (x1::T, x2::T, ..., xNin::T)` of the same type, mapping +the function `f!!` over every differentiable value encountered and updating +`(y1::T, y2::T, ..., yNout::T)`` in-place with the resulting values. + +This is a simple wrapper that verifies that `T` is a type where all differentiable values +can be updated in-place (this uses the `nonactive == true` mode of `isinactivetype`, see +[`IsInactive`](@ref) for details), calls `recursive_map`, and verifies that the returned +value is indeed identically the same tuple `ys`. See [`recursive_map`](@ref) for details. + +Note that this wrapper only supports instances of [`IsInactive`](@ref) for the +`isinactivetype` argument, as this is the only way we can insure consistency between the +upfront compatibility check and actual behavior. If this is not appropriate, use +`recursive_map` directly. +""" +function recursive_map!( + f!!::F, + ys::NTuple{Nout,T}, + xs::NTuple{Nin,T}, + copy_if_inactive::Val=Val(false), + isinactivetype::IsInactive=IsInactive{false}(), +) where {F,Nout,Nin,T} + check_nonactive(T, isinactivetype) + newys = recursive_map(f!!, ys, xs, copy_if_inactive, isinactivetype) + @assert newys === ys + return nothing +end + +function recursive_map!( + seen::IdDict, + f!!::F, + ys::NTuple{Nout,T}, + xs::NTuple{Nin,T}, + copy_if_inactive::Val=Val(false), + isinactivetype::IsInactive=IsInactive{false}(), +) where {F,Nout,Nin,T} + check_nonactive(T, isinactivetype) + newys = recursive_map(seen, f!!, ys, xs, copy_if_inactive, isinactivetype) + @assert newys === ys + return nothing +end + +### recursive_map helpers +@inline _similar(::T) where {T} = ccall(:jl_new_struct_uninit, Any, (Any,), T)::T +@inline _similar(x::T) where {T<:DenseArray} = similar(x)::T +Base.@propagate_inbounds isinitialized(x, i) = isdefined(x, i) +Base.@propagate_inbounds isinitialized(x::DenseArray, i) = isassigned(x, i) +Base.@propagate_inbounds getitem(x, i) = getfield(x, i) +Base.@propagate_inbounds getitem(x::DenseArray, i) = x[i] +Base.@propagate_inbounds setitem!(x, i, v) = setfield_force!(x, i, v) +Base.@propagate_inbounds setitem!(x::DenseArray, i, v) = (x[i] = v; nothing) + +Base.@propagate_inbounds function setfield_force!(x::T, i, v) where {T} + if Base.isconst(T, i) + ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), x, i - 1, v) + else + setfield!(x, i, v) + end + return nothing +end + +Base.@propagate_inbounds function getitems(xs::Tuple{T,T,Vararg{T,N}}, i) where {T,N} + return (getitem(first(xs), i), getitems(Base.tail(xs), i)...) +end + +Base.@propagate_inbounds getitems(xs::Tuple{T}, i) where {T} = (getitem(only(xs), i),) + +Base.@propagate_inbounds function setitems!( # may not show in coverage but is covered via accumulate_into! TODO: ensure coverage via VectorSpace once implemented + xs::Tuple{T,T,Vararg{T,N}}, i, vs::Tuple{ST,ST,Vararg{ST,N}} +) where {T,ST,N} + setitem!(first(xs), i, first(vs)) + setitems!(Base.tail(xs), i, Base.tail(vs)) + return nothing +end + +Base.@propagate_inbounds function setitems!(xs::Tuple{T}, i, vs::Tuple{ST}) where {T,ST} + setitem!(only(xs), i, only(vs)) + return nothing +end + +## cache (seen) helpers +@inline function iscachedtype(::Type{T}) where {T} + # cache all mutable types and any non-isbits types that are also leaf types + return ismutabletype(T) || ((!isbitstype(T)) && isvectortype(T)) +end + +@inline shouldcache(::IdDict, ::Type{T}) where {T} = iscachedtype(T) +@inline shouldcache(::Nothing, ::Type{T}) where {T} = false + +@inline function maybecache!(seen, newys::NTuple{Nout,T}, xs::NTuple{Nin,T}) where {Nout,Nin,T} + if shouldcache(seen, T) + if (Nout == 1) && (Nin == 1) + seen[only(xs)] = only(newys) + else # may not show in coverage but is covered via accumulate_into! TODO: ensure coverage via VectorSpace once implemented + seen[first(xs)] = (newys..., Base.tail(xs)...) + end + end + return nothing +end + +@inline function hascache(seen, xs::NTuple{Nin,T}) where {Nin,T} + return shouldcache(seen, T) ? haskey(seen, first(xs)) : false +end + +@inline function getcached(seen::IdDict, ::Val{Nout}, xs::NTuple{Nin,T}) where {Nout,Nin,T} + newys = if (Nout == 1) && (Nin == 1) + (seen[only(xs)]::T,) + else # may not show in coverage but is covered via accumulate_into! TODO: ensure coverage via VectorSpace once implemented + cache = seen[first(xs)]::NTuple{(Nout + Nin - 1),T} + cachedtail = cache[(Nout+1):end] + check_identical(cachedtail, Base.tail(xs)) # check compatible layout + cache[1:Nout] + end + return newys::NTuple{Nout,T} +end + +## argument validation +Base.@propagate_inbounds function check_initialized(x, i, initialized=true) + if isinitialized(x, i) != initialized + throw_initialized() # TODO: hit this when VectorSpace implemented + end + return nothing +end + +Base.@propagate_inbounds function check_allinitialized( # TODO: hit this when VectorSpace implemented + xs::Tuple{T,T,Vararg{T,N}}, i, initialized=true +) where {T,N} + check_initialized(first(xs), i, initialized) + check_allinitialized(Base.tail(xs), i, initialized) + return nothing +end + +Base.@propagate_inbounds function check_allinitialized( + xs::Tuple{T}, i, initialized=true +) where {T} + check_initialized(only(xs), i, initialized) + return nothing +end + +Base.@propagate_inbounds check_allinitialized(::Tuple{}, i, initialized=true) = nothing + +@inline function check_identical(u, v) # TODO: hit this when VectorSpace implemented + if u !== v + throw_identical() + end + return nothing +end + +@inline function check_nonactive(::Type{T}, isinactivetype::IsInactive) where {T} + if !isinactivetype(T, Val(true)) #=nonactive=# + throw_nonactive() + end + return nothing +end + +# TODO: hit all of these via check_* when VectorSpace implemented +@noinline function throw_initialized() + msg = "recursive_map(!) called on objects whose undefined fields/unassigned elements " + msg *= "don't line up" + throw(ArgumentError(msg)) +end + +@noinline function throw_identical() + msg = "recursive_map(!) called on objects whose layout don't match" + throw(ArgumentError(msg)) +end + +@noinline function throw_nonactive() + msg = "recursive_map! called on objects containing immutable differentiable values" + throw(ArgumentError(msg)) +end + +### EnzymeCore.make_zero(!) implementation +function EnzymeCore.make_zero(prev::T, args::Vararg{Any,M}) where {T,M} + new = if iszero(M) && !IsInactive{false}()(T) && isvectortype(T) # fallback + # IsInactive has precedence over isvectortype for consistency with recursive handler + convert(T, zero(prev)) # convert because zero(prev)::T may fail when eltype(T) is abstract + else + _make_zero_inner(prev, args...) + end + return new::T +end + +function EnzymeCore.make_zero!(val::T, args::Vararg{Any,M}) where {T,M} + @assert !isscalartype(T) # not appropriate for in-place handler + if iszero(M) && !IsInactive{false}()(T) && isvectortype(T) # fallback + # IsInactive has precedence over isvectortype for consistency with recursive handler + fill!(val, false) + else + _make_zero_inner!(val, args...) + end + return nothing +end + +@inline function _make_zero_inner( + prev::T, copy_if_inactive::Val=Val(false), ::Val{runtime_inactive}=Val(false) +) where {T,runtime_inactive} + isinactivetype = IsInactive{runtime_inactive}() + news = recursive_map(_make_zero!!, Val(1), (prev,), copy_if_inactive, isinactivetype) + return only(news)::T +end + +@inline function _make_zero_inner!( + val::T, ::Val{runtime_inactive}=Val(false) +) where {T,runtime_inactive} + isinactivetype = IsInactive{runtime_inactive}() + recursive_map!(_make_zero!!, (val,), (val,), Val(false), isinactivetype) + return nothing +end + +@inline function _make_zero_inner!( + val::T, seen::IdDict, ::Val{runtime_inactive}=Val(false) +) where {T,runtime_inactive} + isinactivetype = IsInactive{runtime_inactive}() + recursive_map!(seen, _make_zero!!, (val,), (val,), Val(false), isinactivetype) + return nothing +end + +function _make_zero!!(prev::T) where {T} + @assert isvectortype(T) # otherwise infinite loop + return (EnzymeCore.make_zero(prev),)::Tuple{T} +end + +function _make_zero!!(val::T, _val::T) where {T} + @assert !isscalartype(T) # not appropriate for in-place handler + @assert isvectortype(T) # otherwise infinite loop + @assert val === _val + EnzymeCore.make_zero!(val) + return (val,)::Tuple{T} +end + +# alternative entry point for passing custom IdDict +function EnzymeCore.make_zero( + ::Type{T}, + seen::IdDict, + prev::T, + copy_if_inactive::Val=Val(false), + ::Val{runtime_inactive}=Val(false), +) where {T,runtime_inactive} + isinactivetype = IsInactive{runtime_inactive}() + news = recursive_map(seen, _make_zero!!, Val(1), (prev,), copy_if_inactive, isinactivetype) + return only(news)::T +end + +end # module RecursiveMaps diff --git a/test/Project.toml b/test/Project.toml index fbc6d754fe..667d94ba1f 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -14,6 +14,7 @@ LLVM_jll = "86de99a1-58d6-5da7-8064-bd56ce2e322c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" +Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" diff --git a/test/make_zero.jl b/test/make_zero.jl deleted file mode 100644 index cbe2f2159f..0000000000 --- a/test/make_zero.jl +++ /dev/null @@ -1,725 +0,0 @@ -module MakeZeroTests - -using Enzyme -using StaticArrays -using Test - -# Universal getters/setters for built-in and custom containers/wrappers -getx(w::Base.RefValue) = w[] -getx(w::Core.Box) = w.contents -getx(w) = first(w) -gety(w) = last(w) - -setx!(w::Base.RefValue, x) = (w[] = x) -setx!(w::Core.Box, x) = (w.contents = x) -setx!(w, x) = (w[begin] = x) -sety!(w, y) = (w[end] = y) - -# non-isbits MArray doesn't support setindex!, so requires a little hack -function setx!(w::MArray{S,T}, x) where {S,T} - if isbitstype(T) - w[begin] = x - else - w.data = (x, Base.tail(w.data)...) - end - return x -end - -function sety!(w::MArray{S,T}, y) where {S,T} - if isbitstype(T) - w[end] = y - else - w.data = (Base.front(w.data)..., y) - end - return y -end - -struct Empty end - -mutable struct MutableEmpty end - -Base.:(==)(::MutableEmpty, ::MutableEmpty) = true - -struct Wrapper{T} - x::T -end - -Base.:(==)(a::Wrapper, b::Wrapper) = (a === b) || (a.x == b.x) -getx(a::Wrapper) = a.x - -mutable struct MutableWrapper{T} - x::T -end - -Base.:(==)(a::MutableWrapper, b::MutableWrapper) = (a === b) || (a.x == b.x) - -getx(a::MutableWrapper) = a.x -setx!(a::MutableWrapper, x) = (a.x = x) - -struct DualWrapper{Tx,Ty} - x::Tx - y::Ty -end - -DualWrapper{T}(x::T, y) where {T} = DualWrapper{T,typeof(y)}(x, y) - -function Base.:(==)(a::DualWrapper, b::DualWrapper) - return (a === b) || ((a.x == b.x) && (a.y == b.y)) -end - -getx(a::DualWrapper) = a.x -gety(a::DualWrapper) = a.y - -mutable struct MutableDualWrapper{Tx,Ty} - x::Tx - y::Ty -end - -MutableDualWrapper{T}(x::T, y) where {T} = MutableDualWrapper{T,typeof(y)}(x, y) - -function Base.:(==)(a::MutableDualWrapper, b::MutableDualWrapper) - return (a === b) || ((a.x == b.x) && (a.y == b.y)) -end - -getx(a::MutableDualWrapper) = a.x -gety(a::MutableDualWrapper) = a.y - -setx!(a::MutableDualWrapper, x) = (a.x = x) -sety!(a::MutableDualWrapper, y) = (a.y = y) - -struct Incomplete{T} - s::String - x::Float64 - w::T - z # not initialized - Incomplete(s, x, w) = new{typeof(w)}(s, x, w) -end - -function Base.:(==)(a::Incomplete, b::Incomplete) - (a === b) && return true - ((a.s == b.s) && (a.x == b.x) && (a.w == b.w)) || return false - if isdefined(a, :z) && isdefined(b, :z) - (a.z == b.z) || return false - elseif isdefined(a, :z) || isdefined(b, :z) - return false - end - return true -end - -mutable struct MutableIncomplete{T} - s::String - const x::Float64 - y::Float64 - z # not initialized - w::T - function MutableIncomplete(s, x, y, w) - ret = new{typeof(w)}(s, x, y) - ret.w = w - return ret - end -end - -function Base.:(==)(a::MutableIncomplete, b::MutableIncomplete) - (a === b) && return true - if (a.s != b.s) || (a.x != b.x) || (a.y != b.y) || (a.w != b.w) - return false - end - if isdefined(a, :z) && isdefined(b, :z) - (a.z == b.z) || return false - elseif isdefined(a, :z) || isdefined(b, :z) - return false - end - return true -end - -mutable struct CustomVector{T} <: AbstractVector{T} - data::Vector{T} -end - -Base.:(==)(a::CustomVector, b::CustomVector) = (a === b) || (a.data == b.data) - -function Enzyme.EnzymeCore.make_zero( - ::Type{CV}, seen::IdDict, prev::CV, ::Val{copy_if_inactive} -) where {CV<:CustomVector{<:AbstractFloat},copy_if_inactive} - @info "make_zero(::CustomVector)" - if haskey(seen, prev) - return seen[prev] - end - new = CustomVector(zero(prev.data))::CV - seen[prev] = new - return new -end - -function Enzyme.EnzymeCore.make_zero!(prev::CustomVector{<:AbstractFloat}, seen)::Nothing - @info "make_zero!(::CustomVector)" - if !isnothing(seen) - if prev in seen - return nothing - end - push!(seen, prev) - end - fill!(prev.data, false) - return nothing -end - -function Enzyme.EnzymeCore.make_zero!(prev::CustomVector{<:AbstractFloat}) - return Enzyme.EnzymeCore.make_zero!(prev, nothing) -end - -struct WithIO{F} # issue 2091 - v::Vector{Float64} - callback::F - function WithIO(v, io) - callback() = println(io, "hello") - return new{typeof(callback)}(v, callback) - end -end - -macro test_noerr(expr) - return quote - @test_nowarn try - # catch errors to get failed test instead of "exception outside of a @test" - $(esc(expr)) - catch e - showerror(stderr, e) - end - end -end - -const scalartypes = [Float32, ComplexF32, Float64, ComplexF64] - -const inactivetup = ("a", Empty(), MutableEmpty()) -const inactivearr = [inactivetup] - -const wrappers = [ - (name="Tuple{X}", f=tuple, N=1, mutable=false, typed=true), - (name="@NamedTuple{x::X}", f=(NamedTuple{(:x,)} ∘ tuple), N=1, mutable=false, typed=true), - (name="struct{X}", f=Wrapper, N=1, mutable=false, typed=true), - - (name="@NamedTuple{x}", f=(@NamedTuple{x} ∘ tuple), N=1, mutable=false, typed=false), - (name="struct{Any}", f=Wrapper{Any}, N=1, mutable=false, typed=false), - - (name="Array{X}", f=(x -> [x]), N=1, mutable=true, typed=true), - (name="Base.RefValue{X}", f=Ref, N=1, mutable=true, typed=true), - (name="mutable struct{X}", f=MutableWrapper, N=1, mutable=true, typed=true), - - (name="Array{Any}", f=(x -> Any[x]), N=1, mutable=true, typed=false), - (name="Base.RefValue{Any}", f=Ref{Any}, N=1, mutable=true, typed=false), - (name="Core.Box", f=Core.Box, N=1, mutable=true, typed=false), - (name="mutable struct{Any}", f=MutableWrapper{Any}, N=1, mutable=true, typed=false), - - (name="Tuple{X,Y}", f=tuple, N=2, mutable=false, typed=true), - (name="@NamedTuple{x::X,y::Y}", f=(NamedTuple{(:x, :y)} ∘ tuple), N=2, mutable=false, typed=true), - (name="struct{X,Y}", f=DualWrapper, N=2, mutable=false, typed=true), - - (name="@NamedTuple{x,y::Y}", f=((x, y) -> @NamedTuple{x,y::typeof(y)}((x, y))), N=2, mutable=false, typed=:partial), - (name="struct{Any,Y}", f=DualWrapper{Any}, N=2, mutable=false, typed=:partial), - - (name="@NamedTuple{x,y}", f=@NamedTuple{x,y} ∘ tuple, N=2, mutable=false, typed=false), - (name="struct{Any}", f=DualWrapper{Any,Any}, N=2, mutable=false, typed=false), - - (name="mutable struct{X,Y}", f=MutableDualWrapper, N=2, mutable=true, typed=true), - - (name="Array{promote_type(X,Y)}", f=((x, y) -> [x, y]), N=2, mutable=true, typed=:promoted), - (name="mutable struct{Any,Y}", f=MutableDualWrapper{Any}, N=2, mutable=true, typed=:partial), - - (name="Array{Any}", f=((x, y) -> Any[x, y]), N=2, mutable=true, typed=false), - (name="mutable struct{Any,Any}", f=MutableDualWrapper{Any,Any}, N=2, mutable=true, typed=false), - - # StaticArrays extension - (name="SVector{1,X}", f=SVector{1} ∘ tuple, N=1, mutable=false, typed=true), - (name="SVector{1,Any}", f=SVector{1,Any} ∘ tuple, N=1, mutable=false, typed=false), - (name="MVector{1,X}", f=MVector{1} ∘ tuple, N=1, mutable=true, typed=true), - (name="MVector{1,Any}", f=MVector{1,Any} ∘ tuple, N=1, mutable=true, typed=false), - (name="SVector{2,promote_type(X,Y)}", f=SVector{2} ∘ tuple, N=2, mutable=false, typed=:promoted), - (name="SVector{2,Any}", f=SVector{2,Any} ∘ tuple, N=2, mutable=false, typed=false), - (name="MVector{2,promote_type(X,Y)}", f=MVector{2} ∘ tuple, N=2, mutable=true, typed=:promoted), - (name="MVector{2,Any}", f=MVector{2,Any} ∘ tuple, N=2, mutable=true, typed=false), -] - -@static if VERSION < v"1.11-" -else -_memory(x::Vector) = Memory{eltype(x)}(x) -push!( - wrappers, - (name="Memory{X}", f=(x -> _memory([x])), N=1, mutable=true, typed=true), - (name="Memory{Any}", f=(x -> _memory(Any[x])), N=1, mutable=true, typed=false), - (name="Memory{promote_type(X,Y)}", f=((x, y) -> _memory([x, y])), N=2, mutable=true, typed=:promoted), - (name="Memory{Any}", f=((x, y) -> _memory(Any[x, y])), N=2, mutable=true, typed=false), -) -end - -function test_make_zero() - @testset "scalars" begin - @testset "$T" for T in scalartypes - x = oneunit(T) - x_makez = make_zero(x) - @test typeof(x_makez) === T # correct type - @test x_makez == zero(T) # correct value - @test x == oneunit(T) # no mutation of original (relevant for BigFloat) - end - end - @testset "nested types" begin - @testset "$T in $(wrapper.name)" for - T in scalartypes, wrapper in filter(w -> (w.N == 1), wrappers) - x = oneunit(T) - w = wrapper.f(x) - w_makez = make_zero(w) - @test typeof(w_makez) === typeof(w) # correct type - @test typeof(getx(w_makez)) === T # correct type - @test getx(w_makez) == zero(T) # correct value - @test getx(w) === x # no mutation of original - @test x == oneunit(T) # no mutation of original (relevant for BigFloat) - @testset "doubly included in $(dualwrapper.name)" for - dualwrapper in filter(w -> (w.N == 2), wrappers) - w_inner = wrapper.f(x) - d_outer = dualwrapper.f(w_inner, w_inner) - d_outer_makez = make_zero(d_outer) - @test typeof(d_outer_makez) === typeof(d_outer) # correct type - @test typeof(getx(d_outer_makez)) === typeof(w_inner) # correct type - @test typeof(getx(getx(d_outer_makez))) === T # correct type - @test getx(d_outer_makez) === gety(d_outer_makez) # correct topology - @test getx(getx(d_outer_makez)) == zero(T) # correct value - @test getx(d_outer) === gety(d_outer) # no mutation of original - @test getx(d_outer) === w_inner # no mutation of original - @test getx(w_inner) === x # no mutation of original - @test x == oneunit(T) # no mutation of original (relevant for BigFloat) - d_inner = dualwrapper.f(x, x) - w_outer = wrapper.f(d_inner) - w_outer_makez = make_zero(w_outer) - @test typeof(w_outer_makez) === typeof(w_outer) # correct type - @test typeof(getx(w_outer_makez)) === typeof(d_inner) # correct type - @test typeof(getx(getx(w_outer_makez))) === T # correct type - @test getx(getx(w_outer_makez)) == gety(getx(w_outer_makez)) # correct topology - @test getx(getx(w_outer_makez)) == zero(T) # correct value - @test getx(w_outer) === d_inner # no mutation of original - @test getx(d_inner) === gety(d_inner) # no mutation of original - @test getx(d_inner) === x # no mutation of original - @test x == oneunit(T) # no mutation of original (relevant for BigFloat) - if wrapper.mutable && !dualwrapper.mutable - # some code paths can only be hit with three layers of wrapping: - # mutable(immutable(mutable(scalar))) - @testset "all wrapped in $(outerwrapper.name)" for - outerwrapper in filter(w -> ((w.N == 1) && w.mutable), wrappers) - w_inner = wrapper.f(x) - d_middle = dualwrapper.f(w_inner, w_inner) - w_outer = outerwrapper.f(d_middle) - w_outer_makez = make_zero(w_outer) - @test typeof(w_outer_makez) === typeof(w_outer) # correct type - @test typeof(getx(w_outer_makez)) === typeof(d_middle) # correct type - @test typeof(getx(getx(w_outer_makez))) === typeof(w_inner) # correct type - @test typeof(getx(getx(getx(w_outer_makez)))) === T # correct type - @test getx(getx(w_outer_makez)) === gety(getx(w_outer_makez)) # correct topology - @test getx(getx(getx(w_outer_makez))) == zero(T) # correct value - @test getx(w_outer) === d_middle # no mutation of original - @test getx(d_middle) === gety(d_middle) # no mutation of original - @test getx(d_middle) === w_inner # no mutation of original - @test getx(w_inner) === x # no mutation of original - @test x == oneunit(T) # no mutation of original (relevant for BigFloat) - end - end - end - end - end - @testset "inactive" begin - @testset "in $(wrapper.name)" for wrapper in wrappers - if wrapper.N == 1 - w = wrapper.f(inactivearr) - w_makez = make_zero(w) - if wrapper.typed == true - @test w_makez === w # preserved wrapper identity if guaranteed const - end - @test typeof(w_makez) === typeof(w) # correct type - @test getx(w_makez) === inactivearr # preserved identity - @test inactivearr[1] === inactivetup # preserved value - @test getx(w) === inactivearr # no mutation of original - else # wrapper.N == 2 - @testset "multiple references" begin - w = wrapper.f(inactivearr, inactivearr) - w_makez = make_zero(w) - if wrapper.typed == true - @test w_makez === w # preserved wrapper identity if guaranteed const - end - @test typeof(w_makez) === typeof(w) # correct type - @test getx(w_makez) === gety(w_makez) # preserved topology - @test getx(w_makez) === inactivearr # preserved identity - @test inactivearr[1] === inactivetup # preserved value - @test getx(w) === gety(w) # no mutation of original - @test getx(w) === inactivearr # no mutation of original - end - @testset "alongside active" begin - a = [1.0] - w = wrapper.f(a, inactivearr) - w_makez = make_zero(w) - @test typeof(w_makez) === typeof(w) # correct type - @test typeof(getx(w_makez)) === typeof(a) # correct type - @test getx(w_makez) == [0.0] # correct value - @test gety(w_makez) === inactivearr # preserved inactive identity - @test inactivearr[1] === inactivetup # preserved inactive value - @test getx(w) === a # no mutation of original - @test a[1] === 1.0 # no mutation of original - @test gety(w) === inactivearr # no mutation of original - if wrapper.typed == :partial - # above: untyped active / typed inactive - # below: untyped inactive / typed active - w = wrapper.f(inactivearr, a) - w_makez = make_zero(w) - @test typeof(w_makez) === typeof(w) # correct type - @test getx(w_makez) === inactivearr # preserved inactive identity - @test inactivearr[1] === inactivetup # preserved inactive value - @test typeof(gety(w_makez)) === typeof(a) # correct type - @test gety(w_makez) == [0.0] # correct value - @test getx(w) === inactivearr # no mutation of original - @test gety(w) === a # no mutation of original - @test a[1] === 1.0 # no mutation of original - end - end - end - end - @testset "copy_if_inactive $value" for (value, args) in [ - ("unspecified", ()), - ("= false", (Val(false),)), - ("= true", (Val(true),)), - ] - a = [1.0] - w = Any[a, inactivearr, inactivearr] - w_makez = make_zero(w, args...) - @test typeof(w_makez) === typeof(w) # correct type - @test typeof(w_makez[1]) === typeof(a) # correct type - @test w_makez[1] == [0.0] # correct value - @test w_makez[2] === w_makez[3] # correct topology (topology should propagate even when copy_if_inactive = Val(true)) - @test w[1] === a # no mutation of original - @test a[1] === 1.0 # no mutation of original - @test w[2] === w[3] # no mutation of original - @test w[2] === inactivearr # no mutation of original - @test inactivearr[1] === inactivetup # no mutation of original - if args == (Val(true),) - @test typeof(w_makez[2]) === typeof(inactivearr) # correct type - @test w_makez[2] == inactivearr # correct value - @test w_makez[2][1] !== inactivetup # correct identity - else - @test w_makez[2] === inactivearr # correct value/type/identity - end - end - end - @testset "heterogeneous containers" begin - scalars, scalarsz = oneunit.(scalartypes), zero.(scalartypes) - wraps, wrapsz = Wrapper.(scalars), Wrapper.(scalarsz) - mwraps, mwrapsz = MutableWrapper.(scalars), MutableWrapper.(scalarsz) - items = (inactivetup..., scalars..., wraps..., mwraps...) - itemsz = (inactivetup..., scalarsz..., wrapsz..., mwrapsz...) - labels = Symbol.("i" .* string.(1:length(items))) - @testset "$name" for (name, c, cz) in [ - ("Tuple", Tuple(items), Tuple(itemsz)), - ("NamedTuple", NamedTuple(labels .=> items), NamedTuple(labels .=> itemsz)), - ("Array", collect(items), collect(itemsz)), - ] - c_makez = make_zero(c) - @test typeof(c_makez) === typeof(c) # correct type - @test all(typeof(czj) === typeof(cj) for (czj, cj) in zip(c_makez, c)) # correct type - @test c_makez == cz # correct value - @test all(czj === inj for (czj, inj) in zip(c_makez, inactivetup)) # preserved inactive identities - @test all(cj === itj for (cj, itj) in zip(c, items)) # no mutation of original - @test all(m.x == oneunit(m.x) for m in mwraps) # no mutation of original - end - end - @testset "circular references" begin - @testset "$(wrapper.name)" for wrapper in ( - filter(w -> (w.mutable && (w.typed in (:partial, false))), wrappers) - ) - a = [1.0] - if wrapper.N == 1 - w = wrapper.f(nothing) - setx!(w, (w, a)) - else - w = wrapper.f(nothing, a) - setx!(w, w) - end - w_makez = @test_noerr make_zero(w) - if wrapper.N == 1 - xz, yz = getx(w_makez) - x, y = getx(w) - else - xz, yz = getx(w_makez), gety(w_makez) - x, y = getx(w), gety(w) - end - @test typeof(w_makez) === typeof(w) # correct type - @test typeof(xz) === typeof(w) # correct type - @test typeof(yz) === typeof(a) # correct type - @test xz === w_makez # correct self-reference - @test yz == [0.0] # correct value - @test x === w # no mutation of original - @test y === a # no mutation of original - @test a[1] === 1.0 # no mutation of original - end - end - @testset "bring your own IdDict" begin - a = [1.0] - seen = IdDict() - a_makez = make_zero(typeof(a), seen, a) - @test typeof(a_makez) === typeof(a) # correct type - @test a_makez == [0.0] # correct value - @test a[1] === 1.0 # no mutation of original - @test haskey(seen, a) # original added to IdDict - @test seen[a] === a_makez # original points to zeroed value - end - @testset "custom leaf type" begin - a = [1.0] - v = CustomVector(a) - # include optional arg Val(false) to avoid calling the custom method directly; - # it should still be invoked - v_makez = @test_logs (:info, "make_zero(::CustomVector)") make_zero(v, Val(false)) - @test typeof(v_makez) === typeof(v) # correct type - @test typeof(v_makez.data) === typeof(a) # correct type - @test v_makez == CustomVector([0.0]) # correct value - @test v.data === a # no mutation of original - @test a[1] === 1.0 # no mutation of original - end - @testset "undefined fields/unassigned elements" begin - @testset "array w inactive/active/mutable/unassigned" begin - a = [1.0] - values = ("a", 1.0, a) - arr = Vector{Any}(undef, 4) - arr[1:3] .= values - arr_makez = make_zero(arr) - @views begin - @test typeof(arr_makez) === typeof(arr) # correct type - @test all(typeof.(arr_makez[1:3]) .=== typeof.(values)) # correct type - @test arr_makez[1:3] == ["a", 0.0, [0.0]] # correct value - @test !isassigned(arr_makez, 4) # propagated undefined - @test all(arr[1:3] .=== values) # no mutation of original - @test !isassigned(arr, 4) # no mutation of original - @test a[1] === 1.0 # no mutation of original - end - end - @testset "struct w inactive/active/mutable/undefined" begin - a = [1.0] - incomplete = Incomplete("a", 1.0, a) - incomplete_makez = make_zero(incomplete) - @test typeof(incomplete_makez) === typeof(incomplete) # correct type - @test typeof(incomplete_makez.w) === typeof(a) # correct type - @test incomplete_makez == Incomplete("a", 0.0, [0.0]) # correct value, propagated undefined - @test a[1] === 1.0 # no mutation of original - end - @testset "mutable struct w inactive/const active/active/mutable/undefined" begin - a = [1.0] - incomplete = MutableIncomplete("a", #=const=#1.0, 1.0, a) - incomplete_makez = make_zero(incomplete) - @test typeof(incomplete_makez) === typeof(incomplete) # correct type - @test typeof(incomplete_makez.w) === typeof(a) # correct type - @test incomplete_makez == MutableIncomplete("a", 0.0, 0.0, [0.0]) # correct value, propagated undefined - @test incomplete == MutableIncomplete("a", 1.0, 1.0, a) # no mutation of original - @test incomplete.w === a # no mutation of original - @test a[1] === 1.0 # no mutation of original - end - end - @testset "containing IO" begin # issue #2091 - f = WithIO([1.0, 2.0], stdout) - df = @test_noerr make_zero(f) - @test df.v == [0.0, 0.0] - @test df.callback === f.callback - end - return nothing -end - -function test_make_zero!() - @testset "nested types" begin - @testset "$T in $(wrapper.name)" for - T in scalartypes, wrapper in filter(w -> (w.N == 1), wrappers) - x = oneunit(T) - if wrapper.mutable - w = wrapper.f(x) - make_zero!(w) - @test typeof(getx(w)) === T # preserved type - @test getx(w) == zero(T) # correct value - @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) - end - @testset "doubly included in $(dualwrapper.name)" for dualwrapper in ( - filter(w -> ((w.N == 2) && (w.mutable || wrapper.mutable)), wrappers) - ) - w_inner = wrapper.f(x) - d_outer = dualwrapper.f(w_inner, w_inner) - make_zero!(d_outer) - @test typeof(getx(d_outer)) === typeof(w_inner) # preserved type - @test typeof(getx(getx(d_outer))) === T # preserved type - @test getx(getx(d_outer)) == zero(T) # correct value - @test getx(d_outer) === gety(d_outer) # preserved topology - @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) - if wrapper.mutable - @test getx(d_outer) === w_inner # preserved identity - end - d_inner = dualwrapper.f(x, x) - w_outer = wrapper.f(d_inner) - make_zero!(w_outer) - @test typeof(getx(w_outer)) === typeof(d_inner) # preserved type - @test typeof(getx(getx(w_outer))) === T # preserved type - @test getx(getx(w_outer)) == zero(T) # correct value - @test getx(getx(w_outer)) === gety(getx(w_outer)) # preserved topology - @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) - if dualwrapper.mutable - @test getx(w_outer) === d_inner # preserved identity - end - if wrapper.mutable && !dualwrapper.mutable - # some code paths can only be hit with three layers of wrapping: - # mutable(immutable(mutable(scalar))) - @assert !dualwrapper.mutable # sanity check - @testset "all wrapped in $(outerwrapper.name)" for - outerwrapper in filter(w -> ((w.N == 1) && w.mutable), wrappers) - w_inner = wrapper.f(x) - d_middle = dualwrapper.f(w_inner, w_inner) - w_outer = outerwrapper.f(d_middle) - make_zero!(w_outer) - @test typeof(getx(w_outer)) === typeof(d_middle) # preserved type - @test typeof(getx(getx(w_outer))) === typeof(w_inner) # preserved type - @test typeof(getx(getx(getx(w_outer)))) === T # preserved type - @test getx(getx(getx(w_outer))) == zero(T) # correct value - @test getx(getx(w_outer)) === gety(getx(w_outer)) # preserved topology - @test getx(getx(w_outer)) === w_inner # preserved identity - @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) - end - end - end - end - end - @testset "inactive" begin - @testset "in $(wrapper.name)" for - wrapper in filter(w -> (w.mutable || (w.typed == true)), wrappers) - if wrapper.N == 1 - w = wrapper.f(inactivearr) - make_zero!(w) - @test getx(w) === inactivearr # preserved identity - @test inactivearr[1] === inactivetup # preserved value - else # wrapper.N == 2 - @testset "multiple references" begin - w = wrapper.f(inactivearr, inactivearr) - make_zero!(w) - @test getx(w) === gety(w) # preserved topology - @test getx(w) === inactivearr # preserved identity - @test inactivearr[1] === inactivetup # preserved value - end - @testset "alongside active" begin - a = [1.0] - w = wrapper.f(a, inactivearr) - make_zero!(w) - @test getx(w) === a # preserved identity - @test a[1] === 0.0 # correct value - @test gety(w) === inactivearr # preserved inactive identity - @test inactivearr[1] === inactivetup # preserved inactive value - end - end - end - end - @testset "heterogeneous containers" begin - mwraps = MutableWrapper.(oneunit.(scalartypes)) - mwrapsz = MutableWrapper.(zero.(scalartypes)) - items = (inactivetup..., mwraps...) - itemsz = (inactivetup..., mwrapsz...) - labels = Symbol.("i" .* string.(1:length(items))) - @testset "$name" for (name, c, cz) in [ - ("Tuple", Tuple(items), Tuple(itemsz)), - ("NamedTuple", NamedTuple(labels .=> items), NamedTuple(labels .=> itemsz)), - ("Array", collect(items), collect(itemsz)), - ] - make_zero!(c) - @test all(cj === itj for (cj, itj) in zip(c, items)) # preserved identities - @test c == cz # correct value - end - end - @testset "circular references" begin - @testset "$(wrapper.name)" for wrapper in ( - filter(w -> (w.mutable && (w.typed in (:partial, false))), wrappers) - ) - a = [1.0] - if wrapper.N == 1 - w = wrapper.f(nothing) - setx!(w, (w, a)) - else - w = wrapper.f(nothing, a) - setx!(w, w) - end - @test_noerr make_zero!(w) - if wrapper.N == 1 - x, y = getx(w) - else - x, y = getx(w), gety(w) - end - @test x === w # preserved self-referential identity - @test y === a # preserved identity - @test a[1] === 0.0 # correct value - end - end - @testset "bring your own IdSet" begin - a = [1.0] - seen = Base.IdSet() - make_zero!(a, seen) - @test a[1] === 0.0 # correct value - @test (a in seen) # object added to IdSet - end - @testset "custom leaf type" begin - a = [1.0] - v = CustomVector(a) - # bringing own IdSet to avoid calling the custom method directly; - # it should still be invoked - @test_logs (:info, "make_zero!(::CustomVector)") make_zero!(v, Base.IdSet()) - @test v.data === a # preserved identity - @test a[1] === 0.0 # correct value - end - @testset "undefined fields/unassigned elements" begin - @testset "array w inactive/active/mutable/unassigned" begin - a = [1.0] - values = ("a", 1.0, a) - arr = Vector{Any}(undef, 4) - arr[1:3] .= values - make_zero!(arr) - @views begin - @test all(typeof.(arr[1:3]) .=== typeof.(values)) # preserved types - @test arr[1:3] == ["a", 0.0, [0.0]] # correct value - @test arr[3] === a # preserved identity - @test !isassigned(arr, 4) # preserved unassigned - end - end - @testset "struct w inactive/active/mutable/undefined" begin - a = [1.0] - incompletearr = [Incomplete("a", 1.0, a)] - make_zero!(incompletearr) - @test incompletearr == [Incomplete("a", 0.0, [0.0])] # correct value, preserved undefined - @test incompletearr[1].w === a # preserved identity - end - @testset "mutable struct w inactive/const active/active/mutable/undefined" begin - a = [1.0] - incomplete = MutableIncomplete("a", #=const=#1.0, 1.0, a) - make_zero!(incomplete) - @test incomplete == MutableIncomplete("a", 0.0, 0.0, [0.0]) # correct value, preserved undefined - @test incomplete.w === a # preserved identity - end - @testset "Array{Tuple{struct w undefined}} (issue #1935)" begin - # old implementation triggered #1935 - # new implementation would work regardless due to limited use of justActive - a = [1.0] - incomplete = Incomplete("a", 1.0, a) - incompletetuparr = [(incomplete,)] - make_zero!(incompletetuparr) - @test typeof(incompletetuparr[1]) === typeof((incomplete,)) # preserved type - @test incompletetuparr == [(Incomplete("a", 0.0, [0.0]),)] # correct value - @test incompletetuparr[1][1].w === a # preserved identity - end - end - @testset "active/mixed type error" begin - @test_throws ArgumentError make_zero!((1.0,)) - @test_throws ArgumentError make_zero!((1.0, [1.0])) - @test_throws ArgumentError make_zero!((Incomplete("a", 1.0, 1.0im),)) # issue #1935 - end - @testset "containing IO" begin # issue #2091 - f = WithIO([1.0, 2.0], stdout) - fwrapped = [f] - @test_noerr make_zero!(fwrapped) - @test fwrapped[1] === f - @test fwrapped[1].v == [0.0, 0.0] - end - return nothing -end - -@testset "make_zero" test_make_zero() -@testset "make_zero!" test_make_zero!() - -end # module MakeZeroTests diff --git a/test/recursive_maps.jl b/test/recursive_maps.jl new file mode 100644 index 0000000000..fc6684d1e5 --- /dev/null +++ b/test/recursive_maps.jl @@ -0,0 +1,951 @@ +module RecursiveMapTests + +using Enzyme +using JLArrays +using Logging +using StaticArrays +using Test + +# Universal getters/setters for built-in and custom containers/wrappers +getx(w::Base.RefValue) = w[] +getx(w::Core.Box) = w.contents +getx(w::JLArray) = JLArrays.@allowscalar first(w) +gety(w::JLArray) = JLArrays.@allowscalar last(w) +getx(w) = first(w) +gety(w) = last(w) + +setx!(w::Base.RefValue, x) = (w[] = x) +setx!(w::Core.Box, x) = (w.contents = x) +setx!(w, x) = (w[begin] = x) +sety!(w, y) = (w[end] = y) + +# non-isbits MArray doesn't support setindex!, so requires a little hack +function setx!(w::MArray{S,T}, x) where {S,T} + if isbitstype(T) + w[begin] = x + else + w.data = (x, Base.tail(w.data)...) + end + return x +end + +function sety!(w::MArray{S,T}, y) where {S,T} + if isbitstype(T) + w[end] = y + else + w.data = (Base.front(w.data)..., y) + end + return y +end + +struct Empty end + +mutable struct MutableEmpty end + +Base.:(==)(::MutableEmpty, ::MutableEmpty) = true + +struct Wrapper{T} + x::T +end + +Base.:(==)(a::Wrapper, b::Wrapper) = (a === b) || (a.x == b.x) +getx(a::Wrapper) = a.x + +mutable struct MutableWrapper{T} + x::T +end + +Base.:(==)(a::MutableWrapper, b::MutableWrapper) = (a === b) || (a.x == b.x) + +getx(a::MutableWrapper) = a.x +setx!(a::MutableWrapper, x) = (a.x = x) + +struct DualWrapper{Tx,Ty} + x::Tx + y::Ty +end + +DualWrapper{T}(x::T, y) where {T} = DualWrapper{T,typeof(y)}(x, y) + +function Base.:(==)(a::DualWrapper, b::DualWrapper) + return (a === b) || ((a.x == b.x) && (a.y == b.y)) +end + +getx(a::DualWrapper) = a.x +gety(a::DualWrapper) = a.y + +mutable struct MutableDualWrapper{Tx,Ty} + x::Tx + y::Ty +end + +MutableDualWrapper{T}(x::T, y) where {T} = MutableDualWrapper{T,typeof(y)}(x, y) + +function Base.:(==)(a::MutableDualWrapper, b::MutableDualWrapper) + return (a === b) || ((a.x == b.x) && (a.y == b.y)) +end + +getx(a::MutableDualWrapper) = a.x +gety(a::MutableDualWrapper) = a.y + +setx!(a::MutableDualWrapper, x) = (a.x = x) +sety!(a::MutableDualWrapper, y) = (a.y = y) + +struct Incomplete{T,U} + s::String + x::Float64 + w::T + y::U # possibly not initialized + z # not initialized + Incomplete(s, x, w) = new{typeof(w),Any}(s, x, w) + Incomplete(s, x, w, y) = new{typeof(w),typeof(y)}(s, x, w, y) +end + +function Base.:(==)(a::Incomplete, b::Incomplete) + (a === b) && return true + ((a.s == b.s) && (a.x == b.x) && (a.w == b.w)) || return false + if isdefined(a, :y) && isdefined(b, :y) + (a.w == b.w) || return false + if isdefined(a, :z) && isdefined(b, :z) + (a.z == b.z) || return false + elseif isdefined(a, :z) || isdefined(b, :z) + return false + end + elseif isdefined(a, :y) || isdefined(b, :y) + return false + end + return true +end + +mutable struct MutableIncomplete{T} + s::String + const x::Float64 + y::Float64 + z # not initialized + w::T + function MutableIncomplete(s, x, y, w) + ret = new{typeof(w)}(s, x, y) + ret.w = w + return ret + end +end + +function Base.:(==)(a::MutableIncomplete, b::MutableIncomplete) + (a === b) && return true + if (a.s != b.s) || (a.x != b.x) || (a.y != b.y) || (a.w != b.w) + return false + end + if isdefined(a, :z) && isdefined(b, :z) + (a.z == b.z) || return false + elseif isdefined(a, :z) || isdefined(b, :z) + return false + end + return true +end + +mutable struct CustomVector{T} <: AbstractVector{T} + data::Vector{T} +end + +Base.:(==)(a::CustomVector, b::CustomVector) = (a === b) || (a.data == b.data) + +function Enzyme.EnzymeCore.isvectortype(::Type{CustomVector{T}}) where {T} + return Enzyme.EnzymeCore.isscalartype(T) +end + +function Enzyme.EnzymeCore.make_zero(prev::CV) where {CV<:CustomVector{<:AbstractFloat}} + @info "make_zero(::CustomVector)" + return CustomVector(zero(prev.data))::CV +end + +function Enzyme.EnzymeCore.make_zero!(prev::CustomVector{<:AbstractFloat}) + @info "make_zero!(::CustomVector)" + fill!(prev.data, false) + return nothing +end + +struct WithIO{F} # issue 2091 + v::Vector{Float64} + callback::F + function WithIO(v, io) + callback() = println(io, "hello") + return new{typeof(callback)}(v, callback) + end +end + +macro test_noerr(expr) + return quote + @test_nowarn try + # catch errors to get failed test instead of "exception outside of a @test" + $(esc(expr)) + catch e + showerror(stderr, e) + end + end +end + +const scalartypes = [Float32, ComplexF32, Float64, ComplexF64, BigFloat, Complex{BigFloat}] + +const inactivebits = (1, Empty()) +const inactivetup = (inactivebits, "a", MutableEmpty()) +const inactivearr = [inactivetup] + +const wrappers = [ + (name="Tuple{X}", f=tuple, N=1, mutable=false, typed=true, bitsonly=false), + (name="@NamedTuple{x::X}", f=(NamedTuple{(:x,)} ∘ tuple), N=1, mutable=false, typed=true, bitsonly=false), + (name="struct{X}", f=Wrapper, N=1, mutable=false, typed=true, bitsonly=false), + + (name="@NamedTuple{x}", f=(@NamedTuple{x} ∘ tuple), N=1, mutable=false, typed=false, bitsonly=false), + (name="struct{Any}", f=Wrapper{Any}, N=1, mutable=false, typed=false, bitsonly=false), + + (name="Array{X}", f=(x -> [x]), N=1, mutable=true, typed=true, bitsonly=false), + (name="Base.RefValue{X}", f=Ref, N=1, mutable=true, typed=true, bitsonly=false), + (name="mutable struct{X}", f=MutableWrapper, N=1, mutable=true, typed=true, bitsonly=false), + + (name="Array{Any}", f=(x -> Any[x]), N=1, mutable=true, typed=false, bitsonly=false), + (name="Base.RefValue{Any}", f=Ref{Any}, N=1, mutable=true, typed=false, bitsonly=false), + (name="Core.Box", f=Core.Box, N=1, mutable=true, typed=false, bitsonly=false), + (name="mutable struct{Any}", f=MutableWrapper{Any}, N=1, mutable=true, typed=false, bitsonly=false), + + (name="Tuple{X,Y}", f=tuple, N=2, mutable=false, typed=true, bitsonly=false), + (name="@NamedTuple{x::X,y::Y}", f=(NamedTuple{(:x, :y)} ∘ tuple), N=2, mutable=false, typed=true, bitsonly=false), + (name="struct{X,Y}", f=DualWrapper, N=2, mutable=false, typed=true, bitsonly=false), + + (name="@NamedTuple{x,y::Y}", f=((x, y) -> @NamedTuple{x,y::typeof(y)}((x, y))), N=2, mutable=false, typed=:partial, bitsonly=false), + (name="struct{Any,Y}", f=DualWrapper{Any}, N=2, mutable=false, typed=:partial, bitsonly=false), + + (name="@NamedTuple{x,y}", f=(@NamedTuple{x,y} ∘ tuple), N=2, mutable=false, typed=false, bitsonly=false), + (name="struct{Any}", f=DualWrapper{Any,Any}, N=2, mutable=false, typed=false, bitsonly=false), + + (name="mutable struct{X,Y}", f=MutableDualWrapper, N=2, mutable=true, typed=true, bitsonly=false), + + (name="Array{promote_type(X,Y)}", f=((x, y) -> [x, y]), N=2, mutable=true, typed=:promoted, bitsonly=false), + (name="mutable struct{Any,Y}", f=MutableDualWrapper{Any}, N=2, mutable=true, typed=:partial, bitsonly=false), + + (name="Array{Any}", f=((x, y) -> Any[x, y]), N=2, mutable=true, typed=false, bitsonly=false), + (name="mutable struct{Any,Any}", f=MutableDualWrapper{Any,Any}, N=2, mutable=true, typed=false, bitsonly=false), + + # StaticArrays extension + (name="SVector{1,X}", f=(SVector{1} ∘ tuple), N=1, mutable=false, typed=true, bitsonly=false), + (name="SVector{1,Any}", f=(SVector{1,Any} ∘ tuple), N=1, mutable=false, typed=false, bitsonly=false), + (name="MVector{1,X}", f=(MVector{1} ∘ tuple), N=1, mutable=true, typed=true, bitsonly=false), + (name="MVector{1,Any}", f=(MVector{1,Any} ∘ tuple), N=1, mutable=true, typed=false, bitsonly=false), + (name="SVector{2,promote_type(X,Y)}", f=(SVector{2} ∘ tuple), N=2, mutable=false, typed=:promoted, bitsonly=false), + (name="SVector{2,Any}", f=(SVector{2,Any} ∘ tuple), N=2, mutable=false, typed=false, bitsonly=false), + (name="MVector{2,promote_type(X,Y)}", f=(MVector{2} ∘ tuple), N=2, mutable=true, typed=:promoted, bitsonly=false), + (name="MVector{2,Any}", f=(MVector{2,Any} ∘ tuple), N=2, mutable=true, typed=false, bitsonly=false), + + # GPUArrays extension + (name="JLArray{X}", f=(x -> JLArray([x])), N=1, mutable=true, typed=true, bitsonly=true), + (name="JLArray{promote_type(X,Y)}", f=((x, y) -> JLArray([x, y])), N=2, mutable=true, typed=:promoted, bitsonly=true), +] + +@static if VERSION < v"1.11-" +else +_memory(x::Vector) = Memory{eltype(x)}(x) +push!( + wrappers, + (name="Memory{X}", f=(x -> _memory([x])), N=1, mutable=true, typed=true, bitsonly=false), + (name="Memory{Any}", f=(x -> _memory(Any[x])), N=1, mutable=true, typed=false, bitsonly=false), + (name="Memory{promote_type(X,Y)}", f=((x, y) -> _memory([x, y])), N=2, mutable=true, typed=:promoted, bitsonly=false), + (name="Memory{Any}", f=((x, y) -> _memory(Any[x, y])), N=2, mutable=true, typed=false, bitsonly=false), +) +end + +function test_make_zero() + @testset "scalars" begin + @testset "$T" for T in scalartypes + x = oneunit(T) + x_makez = make_zero(x) + @test typeof(x_makez) === T # correct type + @test x_makez == zero(T) # correct value + @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + end + end + @testset "nested types" begin + @testset "$T in $(wrapper.name)" for T in scalartypes, wrapper in filter( + w -> (w.N == 1), wrappers + ) + (!wrapper.bitsonly || isbitstype(T)) || continue + x = oneunit(T) + w = wrapper.f(x) + w_makez = make_zero(w) + @test typeof(w_makez) === typeof(w) # correct type + @test typeof(getx(w_makez)) === T # correct type + @test getx(w_makez) == zero(T) # correct value + @test getx(w) === x # no mutation of original + @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + @testset "doubly included in $(dualwrapper.name)" for dualwrapper in filter( + w -> (w.N == 2), wrappers + ) + (!dualwrapper.bitsonly || isbitstype(T)) || continue + w_inner = wrapper.f(x) + if !dualwrapper.bitsonly || isbits(w_inner) + d_outer = dualwrapper.f(w_inner, w_inner) + d_outer_makez = make_zero(d_outer) + @test typeof(d_outer_makez) === typeof(d_outer) # correct type + @test typeof(getx(d_outer_makez)) === typeof(w_inner) # correct type + @test typeof(getx(getx(d_outer_makez))) === T # correct type + @test getx(d_outer_makez) === gety(d_outer_makez) # correct layout + @test getx(getx(d_outer_makez)) == zero(T) # correct value + @test getx(d_outer) === gety(d_outer) # no mutation of original + @test getx(d_outer) === w_inner # no mutation of original + @test getx(w_inner) === x # no mutation of original + @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + end + d_inner = dualwrapper.f(x, x) + if !wrapper.bitsonly || isbits(d_inner) + w_outer = wrapper.f(d_inner) + w_outer_makez = make_zero(w_outer) + @test typeof(w_outer_makez) === typeof(w_outer) # correct type + @test typeof(getx(w_outer_makez)) === typeof(d_inner) # correct type + @test typeof(getx(getx(w_outer_makez))) === T # correct type + @test getx(getx(w_outer_makez)) == gety(getx(w_outer_makez)) # correct layout + @test getx(getx(w_outer_makez)) == zero(T) # correct value + @test getx(w_outer) === d_inner # no mutation of original + @test getx(d_inner) === gety(d_inner) # no mutation of original + @test getx(d_inner) === x # no mutation of original + @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + end + if wrapper.mutable && !dualwrapper.mutable && !dualwrapper.bitsonly + # some code paths can only be hit with three layers of wrapping: + # mutable(immutable(mutable(scalar))) + @testset "all wrapped in $(outerwrapper.name)" for outerwrapper in filter( + w -> ((w.N == 1) && w.mutable && !w.bitsonly), wrappers + ) + w_inner = wrapper.f(x) + d_middle = dualwrapper.f(w_inner, w_inner) + w_outer = outerwrapper.f(d_middle) + w_outer_makez = make_zero(w_outer) + @test typeof(w_outer_makez) === typeof(w_outer) # correct type + @test typeof(getx(w_outer_makez)) === typeof(d_middle) # correct type + @test typeof(getx(getx(w_outer_makez))) === typeof(w_inner) # correct type + @test typeof(getx(getx(getx(w_outer_makez)))) === T # correct type + @test getx(getx(w_outer_makez)) === gety(getx(w_outer_makez)) # correct layout + @test getx(getx(getx(w_outer_makez))) == zero(T) # correct value + @test getx(w_outer) === d_middle # no mutation of original + @test getx(d_middle) === gety(d_middle) # no mutation of original + @test getx(d_middle) === w_inner # no mutation of original + @test getx(w_inner) === x # no mutation of original + @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + end + end + end + end + end + @testset "inactive" begin + @testset "in $(wrapper.name)" for wrapper in wrappers + if wrapper.N == 1 + for (inactive, condition) in [ + (inactivebits, true), + (inactivearr, !wrapper.bitsonly), + ] + condition || continue + w = wrapper.f(inactive) + w_makez = make_zero(w) + if wrapper.typed in (true, :promoted) + if w isa JLArray # needs JLArray activity + @test_broken w_makez === w + else + @test w_makez === w # preserved wrapper identity if guaranteed const + end + end + @test typeof(w_makez) === typeof(w) # correct type + @test getx(w_makez) === inactive # preserved identity + @test getx(w) === inactive # no mutation of original + if inactive === inactivearr + @test inactivearr[1] === inactivetup # preserved value + end + end + @testset "mixed" begin + for (inactive, mixed, condition) in [ + (inactivebits, (1.0, inactivebits), true), + (inactivearr, [1.0, inactivearr], !wrapper.bitsonly), + ] + condition || continue + w = wrapper.f(mixed) + w_makez = make_zero(w) + @test typeof(w_makez) === typeof(w) # correct type + @test typeof(getx(w_makez)) === typeof(mixed) # correct type + @test getx(w_makez)[1] === 0.0 # correct value + @test getx(w_makez)[2] === inactive # preserved inactive identity + @test getx(w) === mixed # no mutation of original + if inactive === inactivearr + @test inactivearr[1] === inactivetup # preserved inactive value + @test mixed[1] === 1.0 # no mutation of original + @test mixed[2] === inactivearr # no mutation of original + end + end + end + else # wrapper.N == 2 + @testset "multiple references" begin + for (inactive, condition) in [ + (inactivebits, true), + (inactivearr, !wrapper.bitsonly), + ] + condition || continue + w = wrapper.f(inactive, inactive) + w_makez = make_zero(w) + if wrapper.typed in (true, :promoted) + if w isa JLArray # needs JLArray activity + @test_broken w_makez === w + else + @test w_makez === w # preserved wrapper identity if guaranteed const + end + end + @test typeof(w_makez) === typeof(w) # correct type + @test getx(w_makez) === gety(w_makez) # preserved layout + @test getx(w_makez) === inactive # preserved identity + @test getx(w) === gety(w) # no mutation of original + @test getx(w) === inactive # no mutation of original + if inactive === inactive + @test inactivearr[1] === inactivetup # preserved value + end + end + end + if !wrapper.bitsonly + @testset "mixed" begin + a = [1.0] + w = wrapper.f(a, inactivearr) + w_makez = make_zero(w) + @test typeof(w_makez) === typeof(w) # correct type + @test typeof(getx(w_makez)) === typeof(a) # correct type + @test getx(w_makez) == [0.0] # correct value + @test gety(w_makez) === inactivearr # preserved inactive identity + @test inactivearr[1] === inactivetup # preserved inactive value + @test getx(w) === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + @test gety(w) === inactivearr # no mutation of original + if wrapper.typed == :partial + # above: untyped active / typed inactive + # below: untyped inactive / typed active + w = wrapper.f(inactivearr, a) + w_makez = make_zero(w) + @test typeof(w_makez) === typeof(w) # correct type + @test getx(w_makez) === inactivearr # preserved inactive identity + @test inactivearr[1] === inactivetup # preserved inactive value + @test typeof(gety(w_makez)) === typeof(a) # correct type + @test gety(w_makez) == [0.0] # correct value + @test getx(w) === inactivearr # no mutation of original + @test gety(w) === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + end + end + end + end + end + @testset "copy_if_inactive $value" for (value, args) in [ + ("unspecified", ()), + ("= false", (Val(false),)), + ("= true", (Val(true),)), + ] + a = [1.0] + w = Any[a, inactivearr, inactivearr] + w_makez = make_zero(w, args...) + @test typeof(w_makez) === typeof(w) # correct type + @test typeof(w_makez[1]) === typeof(a) # correct type + @test w_makez[1] == [0.0] # correct value + @test w_makez[2] === w_makez[3] # correct layout (layout should propagate even when copy_if_inactive = Val(true)) + @test w[1] === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + @test w[2] === w[3] # no mutation of original + @test w[2] === inactivearr # no mutation of original + @test inactivearr[1] === inactivetup # no mutation of original + if args == (Val(true),) + @test typeof(w_makez[2]) === typeof(inactivearr) # correct type + @test w_makez[2] == inactivearr # correct value + @test w_makez[2][1] !== inactivetup # correct identity + else + @test w_makez[2] === inactivearr # correct value/type/identity + end + end + end + @testset "heterogeneous containers" begin + scalars, scalarsz = oneunit.(scalartypes), zero.(scalartypes) + wraps, wrapsz = Wrapper.(scalars), Wrapper.(scalarsz) + mwraps, mwrapsz = MutableWrapper.(scalars), MutableWrapper.(scalarsz) + items = (inactivetup..., scalars..., wraps..., mwraps...) + itemsz = (inactivetup..., scalarsz..., wrapsz..., mwrapsz...) + labels = Symbol.("i" .* string.(1:length(items))) + @testset "$name" for (name, c, cz) in [ + ("Tuple", Tuple(items), Tuple(itemsz)), + ("NamedTuple", NamedTuple(labels .=> items), NamedTuple(labels .=> itemsz)), + ("Array", collect(items), collect(itemsz)), + ] + c_makez = make_zero(c) + @test typeof(c_makez) === typeof(c) # correct type + @test all(typeof(czj) === typeof(cj) for (czj, cj) in zip(c_makez, c)) # correct type + @test c_makez == cz # correct value + @test all(czj === inj for (czj, inj) in zip(c_makez, inactivetup)) # preserved inactive identities + @test all(cj === itj for (cj, itj) in zip(c, items)) # no mutation of original + @test all(m.x == oneunit(m.x) for m in mwraps) # no mutation of original + end + end + @testset "heterogeneous float arrays" begin + b1r, b2r = big"1.0", big"2.0" + b1i, b2i = big"1.0" * im, big"2.0" * im + ar = AbstractFloat[1.0f0, 1.0, b1r, b1r, b2r] + ai = Complex{<:AbstractFloat}[1.0f0im, 1.0im, b1i, b1i, b2i] + for (a, btype) in [(ar, typeof(b1r)), (ai, typeof(b1i))] + a_makez = make_zero(a) + @test a_makez[1] === zero(a[1]) + @test a_makez[2] === zero(a[2]) + @test typeof(a_makez[3]) === btype + @test a_makez[3] == 0 + @test a_makez[4] === a_makez[3] + @test typeof(a_makez[5]) === btype + @test a_makez[5] == 0 + @test a_makez[5] !== a_makez[3] + end + end + @testset "circular references" begin + @testset "$(wrapper.name)" for wrapper in filter( + w -> (w.mutable && (w.typed in (:partial, false))), wrappers + ) + a = [1.0] + if wrapper.N == 1 + w = wrapper.f(nothing) + setx!(w, (w, a)) + else + w = wrapper.f(nothing, a) + setx!(w, w) + end + w_makez = @test_noerr make_zero(w) + if wrapper.N == 1 + xz, yz = getx(w_makez) + x, y = getx(w) + else + xz, yz = getx(w_makez), gety(w_makez) + x, y = getx(w), gety(w) + end + @test typeof(w_makez) === typeof(w) # correct type + @test typeof(xz) === typeof(w) # correct type + @test typeof(yz) === typeof(a) # correct type + @test xz === w_makez # correct self-reference + @test yz == [0.0] # correct value + @test x === w # no mutation of original + @test y === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + end + end + @testset "bring your own IdDict" begin + a = [1.0] + seen = IdDict() + a_makez = make_zero(typeof(a), seen, a) + @test typeof(a_makez) === typeof(a) # correct type + @test a_makez == [0.0] # correct value + @test a[1] === 1.0 # no mutation of original + @test haskey(seen, a) # original added to IdDict + @test seen[a] === a_makez # original points to zeroed value + end + @testset "custom leaf type" begin + a = [1.0] + v = CustomVector(a) + # include optional arg Val(false) to avoid calling the custom method directly; + # it should still be invoked + v_makez = @test_logs (:info, "make_zero(::CustomVector)") make_zero(v, Val(false)) + @test typeof(v_makez) === typeof(v) # correct type + @test typeof(v_makez.data) === typeof(a) # correct type + @test v_makez == CustomVector([0.0]) # correct value + @test v.data === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + end + @testset "runtime inactive" begin + a = [1.0] + v = CustomVector(a) + with_logger(SimpleLogger(Warn)) do # silence @info "make_zero(::CustomVector)" + # ensure compile-time methods are evaluated while CustomVector is considered active + @assert !EnzymeRules.inactive_type(CustomVector) + v_makez = make_zero(v, Val(false), Val(false)) + @assert v_makez == CustomVector([0.0]) + + # verify that runtime methods also see CustomVector as active + v_makez = make_zero(v, Val(false), Val(true)) + @test v_makez == CustomVector([0.0]) + + # mark CustomVector as inactive + @eval @inline EnzymeRules.inactive_type(::Type{<:CustomVector}) = true + + # runtime_inactive == false => redefined inactive_type should have no effect + v_makez = @invokelatest make_zero(v, Val(false), Val(false)) + @test v_makez == CustomVector([0.0]) + + # runtime_inactive == true => redefined inactive_type should take effect: + # CustomVector considered inactive and won't be zeroed, but + # shared/copied according to copy_if_inactive instead + v_makez = @invokelatest make_zero(v, Val(false), Val(true)) + @test v_makez === v + v_makez = @invokelatest make_zero(v, Val(true), Val(true)) + @test v_makez !== v + @test v_makez == CustomVector([1.0]) + + # mark CustomVector as active again + @eval @inline EnzymeRules.inactive_type(::Type{<:CustomVector}) = false + + # verify that both compile-time and runtime methods see CustomVector as active + v_makez = @invokelatest make_zero(v, Val(false), Val(false)) + @test v_makez == CustomVector([0.0]) + v_makez = @invokelatest make_zero(v, Val(false), Val(true)) + @test v_makez == CustomVector([0.0]) + end + end + @testset "undefined fields/unassigned elements" begin + @testset "array w inactive/active/mutable/unassigned" begin + a = [1.0] + values = ("a", 1.0, a) + arr = Vector{Any}(undef, 4) + arr[1:3] .= values + arr_makez = make_zero(arr) + @views begin + @test typeof(arr_makez) === typeof(arr) # correct type + @test all(typeof.(arr_makez[1:3]) .=== typeof.(values)) # correct type + @test arr_makez[1:3] == ["a", 0.0, [0.0]] # correct value + @test !isassigned(arr_makez, 4) # propagated undefined + @test all(arr[1:3] .=== values) # no mutation of original + @test !isassigned(arr, 4) # no mutation of original + @test a[1] === 1.0 # no mutation of original + end + end + @testset "struct w inactive/active/mutable/undefined" begin + a = [1.0] + @testset "single undefined" begin + incomplete = Incomplete("a", 1.0, a, nothing) + incomplete_makez = make_zero(incomplete) + @test typeof(incomplete_makez) === typeof(incomplete) # correct type + @test typeof(incomplete_makez.w) === typeof(a) # correct type + @test incomplete_makez == Incomplete("a", 0.0, [0.0], nothing) # correct value, propagated undefined + @test a[1] === 1.0 # no mutation of original + end + @testset "multiple undefined" begin + incomplete = Incomplete("a", 1.0, a) + incomplete_makez = make_zero(incomplete) + @test typeof(incomplete_makez) === typeof(incomplete) # correct type + @test typeof(incomplete_makez.w) === typeof(a) # correct type + @test incomplete_makez == Incomplete("a", 0.0, [0.0]) # correct value, propagated undefined + @test a[1] === 1.0 # no mutation of original + end + end + @testset "mutable struct w inactive/const active/active/mutable/undefined" begin + a = [1.0] + incomplete = MutableIncomplete("a", #=const=#1.0, 1.0, a) + incomplete_makez = make_zero(incomplete) + @test typeof(incomplete_makez) === typeof(incomplete) # correct type + @test typeof(incomplete_makez.w) === typeof(a) # correct type + @test incomplete_makez == MutableIncomplete("a", 0.0, 0.0, [0.0]) # correct value, propagated undefined + @test incomplete == MutableIncomplete("a", 1.0, 1.0, a) # no mutation of original + @test incomplete.w === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + end + end + @testset "containing IO" begin # issue #2091 + f = WithIO([1.0, 2.0], stdout) + df = @test_noerr make_zero(f) + @test df.v == [0.0, 0.0] + @test df.callback === f.callback + end + return nothing +end + +function test_make_zero!() + @testset "nested types" begin + @testset "$T in $(wrapper.name)" for T in scalartypes, wrapper in filter( + w -> (w.N == 1), wrappers + ) + (!wrapper.bitsonly || isbitstype(T)) || continue + x = oneunit(T) + if wrapper.mutable + w = wrapper.f(x) + make_zero!(w) + @test typeof(getx(w)) === T # preserved type + @test getx(w) == zero(T) # correct value + @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) + end + @testset "doubly included in $(dualwrapper.name)" for dualwrapper in ( + filter(w -> ((w.N == 2) && (w.mutable || wrapper.mutable)), wrappers) + ) + (!dualwrapper.bitsonly || isbitstype(T)) || continue + w_inner = wrapper.f(x) + if !dualwrapper.bitsonly || isbits(w_inner) + d_outer = dualwrapper.f(w_inner, w_inner) + make_zero!(d_outer) + @test typeof(getx(d_outer)) === typeof(w_inner) # preserved type + @test typeof(getx(getx(d_outer))) === T # preserved type + @test getx(getx(d_outer)) == zero(T) # correct value + @test getx(d_outer) === gety(d_outer) # preserved layout + @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) + if wrapper.mutable + @test getx(d_outer) === w_inner # preserved identity + end + end + d_inner = dualwrapper.f(x, x) + if !wrapper.bitsonly || isbits(d_inner) + w_outer = wrapper.f(d_inner) + make_zero!(w_outer) + @test typeof(getx(w_outer)) === typeof(d_inner) # preserved type + @test typeof(getx(getx(w_outer))) === T # preserved type + @test getx(getx(w_outer)) == zero(T) # correct value + @test getx(getx(w_outer)) === gety(getx(w_outer)) # preserved layout + @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) + if dualwrapper.mutable + @test getx(w_outer) === d_inner # preserved identity + end + end + if wrapper.mutable && !dualwrapper.mutable && !dualwrapper.bitsonly + # some code paths can only be hit with three layers of wrapping: + # mutable(immutable(mutable(scalar))) + @testset "all wrapped in $(outerwrapper.name)" for outerwrapper in filter( + w -> ((w.N == 1) && w.mutable && !w.bitsonly), wrappers + ) + w_inner = wrapper.f(x) + d_middle = dualwrapper.f(w_inner, w_inner) + w_outer = outerwrapper.f(d_middle) + make_zero!(w_outer) + @test typeof(getx(w_outer)) === typeof(d_middle) # preserved type + @test typeof(getx(getx(w_outer))) === typeof(w_inner) # preserved type + @test typeof(getx(getx(getx(w_outer)))) === T # preserved type + @test getx(getx(getx(w_outer))) == zero(T) # correct value + @test getx(getx(w_outer)) === gety(getx(w_outer)) # preserved layout + @test getx(getx(w_outer)) === w_inner # preserved identity + @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) + end + end + end + end + end + @testset "inactive" begin + @testset "in $(wrapper.name)" for wrapper in filter( + w -> (w.mutable || (w.typed == true)), wrappers + ) + if wrapper.N == 1 + for (inactive, condition) in [ + (inactivebits, true), + (inactivearr, !wrapper.bitsonly), + ] + condition || continue + w = wrapper.f(inactive) + make_zero!(w) + @test getx(w) === inactive # preserved identity + if inactive === inactivearr + @test inactivearr[1] === inactivetup # preserved value + end + end + @testset "mixed" begin + for (inactive, mixed, condition) in [ + (inactivebits, (1.0, inactivebits), wrapper.mutable), + (inactivearr, [1.0, inactivearr], !wrapper.bitsonly), + ] + condition || continue + w = wrapper.f(mixed) + make_zero!(w) + @test getx(w)[1] === 0.0 + @test getx(w)[2] === inactive + if inactive === inactivearr + @test getx(w) === mixed # preserved identity + @test inactivearr[1] === inactivetup # preserved value + end + end + end + else # wrapper.N == 2 + @testset "multiple references" begin + for (inactive, condition) in [ + (inactivebits, true), + (inactivearr, !wrapper.bitsonly), + ] + condition || continue + w = wrapper.f(inactive, inactive) + make_zero!(w) + @test getx(w) === gety(w) # preserved layout + @test getx(w) === inactive # preserved identity + if inactive === inactivearr + @test inactivearr[1] === inactivetup # preserved value + end + end + end + if !wrapper.bitsonly + @testset "mixed" begin + a = [1.0] + w = wrapper.f(a, inactivearr) + make_zero!(w) + @test getx(w) === a # preserved identity + @test a[1] === 0.0 # correct value + @test gety(w) === inactivearr # preserved inactive identity + @test inactivearr[1] === inactivetup # preserved inactive value + end + end + end + end + end + @testset "heterogeneous containers" begin + mwraps = MutableWrapper.(oneunit.(scalartypes)) + mwrapsz = MutableWrapper.(zero.(scalartypes)) + items = (inactivetup..., mwraps...) + itemsz = (inactivetup..., mwrapsz...) + labels = Symbol.("i" .* string.(1:length(items))) + @testset "$name" for (name, c, cz) in [ + ("Tuple", Tuple(items), Tuple(itemsz)), + ("NamedTuple", NamedTuple(labels .=> items), NamedTuple(labels .=> itemsz)), + ("Array", collect(items), collect(itemsz)), + ] + make_zero!(c) + @test all(cj === itj for (cj, itj) in zip(c, items)) # preserved identities + @test c == cz # correct value + end + end + @testset "heterogeneous float arrays" begin + b1r, b2r = big"1.0", big"2.0" + b1i, b2i = big"1.0" * im, big"2.0" * im + ar = AbstractFloat[1.0f0, 1.0, b1r, b1r, b2r] + ai = Complex{<:AbstractFloat}[1.0f0im, 1.0im, b1i, b1i, b2i] + for (a, btype) in [(ar, typeof(b1r)), (ai, typeof(b1i))] + a1, a2 = a[1], a[2] + make_zero!(a) + @test a[1] === zero(a1) + @test a[2] === zero(a2) + @test typeof(a[3]) === btype + @test a[3] == 0 + @test a[4] === a[3] + @test typeof(a[5]) === btype + @test a[5] == 0 + @test a[5] !== a[3] + end + end + @testset "circular references" begin + @testset "$(wrapper.name)" for wrapper in filter( + w -> (w.mutable && (w.typed in (:partial, false))), wrappers + ) + a = [1.0] + if wrapper.N == 1 + w = wrapper.f(nothing) + setx!(w, (w, a)) + else + w = wrapper.f(nothing, a) + setx!(w, w) + end + @test_noerr make_zero!(w) + if wrapper.N == 1 + x, y = getx(w) + else + x, y = getx(w), gety(w) + end + @test x === w # preserved self-referential identity + @test y === a # preserved identity + @test a[1] === 0.0 # correct value + end + end + @testset "bring your own IdDict" begin + a = [1.0] + seen = IdDict() + make_zero!(a, seen) + @test a[1] === 0.0 # correct value + @test haskey(seen, a) # object added to IdDict + @test seen[a] === a # object points to zeroed value, i.e., itself + end + @testset "custom leaf type" begin + a = [1.0] + v = CustomVector(a) + # bringing own IdDict to avoid calling the custom method directly; + # it should still be invoked + @test_logs (:info, "make_zero!(::CustomVector)") make_zero!(v, IdDict()) + @test v.data === a # preserved identity + @test a[1] === 0.0 # correct value + end + @testset "runtime inactive" begin + a = [1.0] + v = CustomVector(a) + with_logger(SimpleLogger(Warn)) do # silence @info "make_zero!(::CustomVector)" + # ensure compile-time methods are evaluated while CustomVector is considered active + @assert !EnzymeRules.inactive_type(CustomVector) + make_zero!(v, Val(false)) + @assert v == CustomVector([0.0]) + + # verify that runtime methods also see CustomVector as active + v.data[1] = 1.0 + make_zero!(v, Val(true)) + @test v == CustomVector([0.0]) + + # mark CustomVector as inactive + @eval @inline EnzymeRules.inactive_type(::Type{<:CustomVector}) = true + + # runtime_inactive == false => compile-time methods still used, redefined + # inactive_type should have no effect + v.data[1] = 1.0 + @invokelatest make_zero!(v, Val(false)) + @test v == CustomVector([0.0]) + + # runtime_inactive == true => redefined inactive_type should take effect + # CustomVector considered inactive and won't be zeroed + v.data[1] = 1.0 + @invokelatest make_zero!(v, Val(true)) + @test v == CustomVector([1.0]) + + # mark CustomVector as active again + @eval @inline EnzymeRules.inactive_type(::Type{<:CustomVector}) = false + + # verify that both compile-time and runtime methods see CustomVector as active + v.data[1] = 1.0 + @invokelatest make_zero!(v, Val(false)) + @test v == CustomVector([0.0]) + v.data[1] = 1.0 + @invokelatest make_zero!(v, Val(true)) + @test v == CustomVector([0.0]) + end + end + @testset "undefined fields/unassigned elements" begin + @testset "array w inactive/active/mutable/unassigned" begin + a = [1.0] + values = ("a", 1.0, a) + arr = Vector{Any}(undef, 4) + arr[1:3] .= values + make_zero!(arr) + @views begin + @test all(typeof.(arr[1:3]) .=== typeof.(values)) # preserved types + @test arr[1:3] == ["a", 0.0, [0.0]] # correct value + @test arr[3] === a # preserved identity + @test !isassigned(arr, 4) # preserved unassigned + end + end + @testset "struct w inactive/active/mutable/undefined" begin + a = [1.0] + incompletearr = [Incomplete("a", 1.0, a)] + make_zero!(incompletearr) + @test incompletearr == [Incomplete("a", 0.0, [0.0])] # correct value, preserved undefined + @test incompletearr[1].w === a # preserved identity + end + @testset "mutable struct w inactive/const active/active/mutable/undefined" begin + a = [1.0] + incomplete = MutableIncomplete("a", #=const=#1.0, 1.0, a) + make_zero!(incomplete) + @test incomplete == MutableIncomplete("a", 0.0, 0.0, [0.0]) # correct value, preserved undefined + @test incomplete.w === a # preserved identity + end + @testset "Array{Tuple{struct w undefined}} (issue #1935)" begin + # old implementation of make_zero! triggered #1935 + # new implementation would work regardless due to limited use of justActive + a = [1.0] + incomplete = Incomplete("a", 1.0, a) + incompletetuparr = [(incomplete,)] + make_zero!(incompletetuparr) + @test typeof(incompletetuparr[1]) === typeof((incomplete,)) # preserved type + @test incompletetuparr == [(Incomplete("a", 0.0, [0.0]),)] # correct value + @test incompletetuparr[1][1].w === a # preserved identity + end + end + @testset "active/mixed type error" begin + @test_throws ArgumentError make_zero!((1.0,)) + @test_throws ArgumentError make_zero!((1.0, [1.0])) + @test_throws ArgumentError make_zero!((Incomplete("a", 1.0, 1.0im),)) # issue #1935 + end + @testset "containing IO" begin # issue #2091 + f = WithIO([1.0, 2.0], stdout) + fwrapped = [f] + @test_noerr make_zero!(fwrapped) + @test fwrapped[1] === f + @test fwrapped[1].v == [0.0, 0.0] + end + return nothing +end + +end # module RecursiveMapTests + +@testset "make_zero" RecursiveMapTests.test_make_zero() +@testset "make_zero!" RecursiveMapTests.test_make_zero!() diff --git a/test/runtests.jl b/test/runtests.jl index 5c5d70d9fa..dc8e583788 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -75,7 +75,7 @@ include("abi.jl") include("typetree.jl") include("passes.jl") include("optimize.jl") -include("make_zero.jl") +include("recursive_maps.jl") include("rules.jl") include("rrules.jl") @@ -440,6 +440,25 @@ make3() = (1.0, 2.0, 3.0) da = [2.7] @test autodiff(Forward, sumdeepcopy, Duplicated(a, da))[1] ≈ 2.7 + # Nested containers to test nontrivial recursion in deepcopy reverse rule + b = [[3.14]] + db = [[0.0]] + sumdeepcopy_nested(x) = sum(sum, deepcopy(x)) + autodiff(Reverse, sumdeepcopy_nested, Duplicated(b, db)) + @test db[1][1] ≈ 1.0 + + c_inner = [3.14] + dc_inner = [0.0] + c = [c_inner, c_inner] + dc = [dc_inner, dc_inner] + autodiff(Reverse, sumdeepcopy_nested, Duplicated(c, dc)) + @test dc[1] === dc[2] + @test dc[1][1] ≈ 2.0 + + d = [(3.14,)] + dd = [(0.0,)] + autodiff(Reverse, sumdeepcopy_nested, Duplicated(d, dd)) + @test dd[1][1] ≈ 1.0 end @testset "Deferred and deferred thunk" begin @@ -533,94 +552,70 @@ end @test_throws MethodError autodiff(ReverseHolomorphic, mul3, Active, Active(z)) @test_throws MethodError autodiff(ReverseHolomorphic, mul3, Active{Complex}, Active(z)) - vals = Complex{Float64}[3.4 + 2.7im] - dvals = Complex{Float64}[0.0] - autodiff(ReverseHolomorphic, sum, Active, Duplicated(vals, dvals)) - @test vals[1] ≈ 3.4 + 2.7im - @test dvals[1] ≈ 1.0 + function reverse_holomorphic_array_tests( + f, val, dval_expected; val_expected=val, ret=Active, mapf=true + ) + vals = ComplexF64[val] + dvals = ComplexF64[zero(val)] + autodiff(ReverseHolomorphic, f, ret, Duplicated(vals, dvals)) + @test vals[1] ≈ val_expected + @test dvals[1] ≈ dval_expected - sumsq(x) = sum(x .* x) + # Use tuple to test out-of-place accumulate_seen! base case + tvals = [(ComplexF64(val),)] + dtvals = [(ComplexF64(zero(val)),)] + ft = mapf ? v -> first(map(f, v)) : f + autodiff(ReverseHolomorphic, ft, ret, Duplicated(tvals, dtvals)) + @test tvals[1][1] ≈ val_expected + @test dtvals[1][1] ≈ dval_expected + end - vals = Complex{Float64}[3.4 + 2.7im] - dvals = Complex{Float64}[0.0] - autodiff(ReverseHolomorphic, sumsq, Active, Duplicated(vals, dvals)) - @test vals[1] ≈ 3.4 + 2.7im - @test dvals[1] ≈ 2 * (3.4 + 2.7im) + @testset "sum" reverse_holomorphic_array_tests(sum, 3.4 + 2.7im, 1.0) + + sumsq(x) = sum(x .* x) + @testset "sumsq" reverse_holomorphic_array_tests(sumsq, 3.4 + 2.7im, 2(3.4 + 2.7im)) sumsq2(x) = sum(abs2.(x)) - vals = Complex{Float64}[3.4 + 2.7im] - dvals = Complex{Float64}[0.0] - autodiff(ReverseHolomorphic, sumsq2, Active, Duplicated(vals, dvals)) - @test vals[1] ≈ 3.4 + 2.7im - @test dvals[1] ≈ 2 * (3.4 + 2.7im) + @testset "sumsq2" reverse_holomorphic_array_tests(sumsq2, 3.4 + 2.7im, 2(3.4 + 2.7im)) sumsq2C(x) = Complex{Float64}(sum(abs2.(x))) - vals = Complex{Float64}[3.4 + 2.7im] - dvals = Complex{Float64}[0.0] - autodiff(ReverseHolomorphic, sumsq2C, Active, Duplicated(vals, dvals)) - @test vals[1] ≈ 3.4 + 2.7im - @test dvals[1] ≈ 3.4 - 2.7im - - sumsq3(x) = sum(x .* conj(x)) - vals = Complex{Float64}[3.4 + 2.7im] - dvals = Complex{Float64}[0.0] - autodiff(ReverseHolomorphic, sumsq3, Active, Duplicated(vals, dvals)) - @test vals[1] ≈ 3.4 + 2.7im - @test dvals[1] ≈ 3.4 - 2.7im - - sumsq3R(x) = Float64(sum(x .* conj(x))) - vals = Complex{Float64}[3.4 + 2.7im] - dvals = Complex{Float64}[0.0] - autodiff(ReverseHolomorphic, sumsq3R, Active, Duplicated(vals, dvals)) - @test vals[1] ≈ 3.4 + 2.7im - @test dvals[1] ≈ 2 * (3.4 + 2.7im) + @testset "sumsq2C" reverse_holomorphic_array_tests(sumsq2C, 3.4 + 2.7im, 3.4 - 2.7im) + + sumsq3(x) = sum(x .* conj.(x)) + @testset "sumsq3" reverse_holomorphic_array_tests(sumsq3, 3.4 + 2.7im, 3.4 - 2.7im) + + sumsq3R(x) = Float64(sum(x .* conj.(x))) + @testset "sumsq3R" reverse_holomorphic_array_tests(sumsq3R, 3.4 + 2.7im, 2(3.4 + 2.7im)) function setinact(z) - z[1] *= 2 + z[1] = 2 .* z[1] # works for both [x] and [(x,)] nothing end - - vals = Complex{Float64}[3.4 + 2.7im] - dvals = Complex{Float64}[0.0] - autodiff(ReverseHolomorphic, setinact, Const, Duplicated(vals, dvals)) - @test vals[1] ≈ 2 * (3.4 + 2.7im) - @test dvals[1] ≈ 0.0 - + @testset "setinact" reverse_holomorphic_array_tests( + setinact, 3.4 + 2.7im, 0.0; val_expected=2(3.4 + 2.7im), ret=Const, mapf=false + ) function setinact2(z) - z[1] *= 2 + z[1] = 2 .* z[1] # works for both [x] and [(x,)] return 0.0+1.0im end - - vals = Complex{Float64}[3.4 + 2.7im] - dvals = Complex{Float64}[0.0] - autodiff(ReverseHolomorphic, setinact2, Const, Duplicated(vals, dvals)) - @test vals[1] ≈ 2 * (3.4 + 2.7im) - @test dvals[1] ≈ 0.0 - - vals = Complex{Float64}[3.4 + 2.7im] - dvals = Complex{Float64}[0.0] - autodiff(ReverseHolomorphic, setinact2, Active, Duplicated(vals, dvals)) - @test vals[1] ≈ 2 * (3.4 + 2.7im) - @test dvals[1] ≈ 0.0 - + @testset "setinact2 Const" reverse_holomorphic_array_tests( + setinact2, 3.4 + 2.7im, 0.0; val_expected=2(3.4 + 2.7im), ret=Const, mapf=false + ) + @testset "setinact2 Active" reverse_holomorphic_array_tests( + setinact2, 3.4 + 2.7im, 0.0; val_expected=2(3.4 + 2.7im), ret=Active, mapf=false + ) function setact(z) - z[1] *= 2 - return z[1] + z[1] = 2 .* z[1] # works for both [x] and [(x,)] + return z[1][1] # returns scalar for both [x] and [(x,)] end - - vals = Complex{Float64}[3.4 + 2.7im] - dvals = Complex{Float64}[0.0] - autodiff(ReverseHolomorphic, setact, Const, Duplicated(vals, dvals)) - @test vals[1] ≈ 2 * (3.4 + 2.7im) - @test dvals[1] ≈ 0.0 - - vals = Complex{Float64}[3.4 + 2.7im] - dvals = Complex{Float64}[0.0] - autodiff(ReverseHolomorphic, setact, Active, Duplicated(vals, dvals)) - @test vals[1] ≈ 2 * (3.4 + 2.7im) - @test dvals[1] ≈ 2.0 + @testset "setact Const" reverse_holomorphic_array_tests( + setact, 3.4 + 2.7im, 0.0; val_expected=2(3.4 + 2.7im), ret=Const, mapf=false + ) + @testset "setact Active" reverse_holomorphic_array_tests( + setact, 3.4 + 2.7im, 2.0; val_expected=2(3.4 + 2.7im), ret=Active, mapf=false + ) function upgrade(z) z = ComplexF64(z) From 5a5715c8c76465cbcc9c5c287b0cd0c3f718b020 Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Thu, 9 Jan 2025 21:14:09 -0800 Subject: [PATCH 02/17] Minor tweaks to typeasserts and accumulate_seen ...and a nonsense default argument --- src/typeutils/recursive_add.jl | 22 +++++++++++----------- src/typeutils/recursive_maps.jl | 10 +++++----- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/typeutils/recursive_add.jl b/src/typeutils/recursive_add.jl index b6e12abb39..0bf4a16ed0 100644 --- a/src/typeutils/recursive_add.jl +++ b/src/typeutils/recursive_add.jl @@ -21,7 +21,7 @@ non-differentiable values. function recursive_add(x::T, y::T, f::F=identity, forcelhs::L=guaranteed_const) where {T,F,L} function addf(xi::S, yi::S) where {S} @assert EnzymeCore.isvectortype(S) - return (xi + f(yi),)::Tuple{S} + return ((xi + f(yi))::S,) end return only(recursive_map(addf, Val(1), (x, y), Val(false), forcelhs))::T end @@ -59,33 +59,33 @@ end function accumulate_seen!( f::F, seen::IdDict, isinactivetype::RecursiveMaps.IsInactive ) where {F} + isinactivetype_or_seen = RecursiveMaps.IsInactive( + isinactivetype, RecursiveMaps.iscachedtype + ) for (k, v) in seen - _accumulate_seen_item!(f, k, v, isinactivetype) + _accumulate_seen_item!(f, k, v, isinactivetype, isinactivetype_or_seen) end return nothing end function _accumulate_seen_item!( - f::F, k::T, v::T, isinactivetype::RecursiveMaps.IsInactive + f::F, k::T, v::T, isinactivetype, isinactivetype_or_seen ) where {F,T} function addf!!(ki::S, vi::S) where {S} @assert EnzymeCore.isvectortype(S) - return (ki .+ f.(vi),)::Tuple{S} + return ((ki .+ f.(vi))::S,) end function addf!!(ki::S, _ki::S, vi::S) where {S} @assert !EnzymeCore.isscalartype(S) @assert EnzymeCore.isvectortype(S) @assert ki === _ki ki .+= f.(vi) - return (ki,)::Tuple{S} + return (ki::S,) end RecursiveMaps.check_nonactive(T, isinactivetype) if !isinactivetype(T) - is_inactive_or_seen_type = RecursiveMaps.IsInactive( - isinactivetype, RecursiveMaps.iscachedtype - ) newks = RecursiveMaps.recursive_map_inner( - nothing, addf!!, (k,), (k, v), Val(false), is_inactive_or_seen_type + nothing, addf!!, (k,), (k, v), Val(false), isinactivetype_or_seen ) @assert only(newks) === k end @@ -115,7 +115,7 @@ function accumulate_into!(into::T, from::T) where {T} # may not show in coverage but both base cases are covered via deepcopy custom rule tests function accumulate_into!!(into_i::S, from_i::S) where {S} @assert EnzymeCore.isvectortype(S) - return (into_i + from_i, convert(S, zero(from_i)))::Tuple{S,S} + return ((into_i + from_i)::S, convert(S, zero(from_i))::S) end function accumulate_into!!(into_i::S, from_i::S, _into_i::S, _from_i::S) where {S} @assert !EnzymeCore.isscalartype(S) @@ -123,7 +123,7 @@ function accumulate_into!(into::T, from::T) where {T} @assert (into_i === _into_i) && (from_i === _from_i) into_i .+= from_i fill!(from_i, false) - return (into_i, from_i)::Tuple{S,S} + return (into_i::S, from_i::S) end recursive_map!(accumulate_into!!, (into, from), (into, from)) return nothing diff --git a/src/typeutils/recursive_maps.jl b/src/typeutils/recursive_maps.jl index 245ce22632..bad7083b41 100644 --- a/src/typeutils/recursive_maps.jl +++ b/src/typeutils/recursive_maps.jl @@ -1,7 +1,7 @@ module RecursiveMaps using EnzymeCore: EnzymeCore, isvectortype, isscalartype -using ..Compiler: Compiler, guaranteed_const, guaranteed_const_nongen, guaranteed_nonactive, +using ..Compiler: guaranteed_const, guaranteed_const_nongen, guaranteed_nonactive, guaranteed_nonactive_nongen ### IsInactive: helper for creating consistent inactive/nonactive type checkers @@ -225,7 +225,7 @@ function recursive_map( f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, - copy_if_inactive::Val=Val{false}, + copy_if_inactive::Val=Val(false), isinactivetype::L=IsInactive{false}(), ) where {F,Nout,Nin,T,L} newys = if isinactivetype(T) @@ -244,7 +244,7 @@ function recursive_map( f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, - copy_if_inactive::Val=Val{false}, + copy_if_inactive::Val=Val(false), isinactivetype::L=IsInactive{false}(), ) where {F,Nout,Nin,T,L} # determine whether to continue recursion, copy/share, or retrieve from cache @@ -742,7 +742,7 @@ end function _make_zero!!(prev::T) where {T} @assert isvectortype(T) # otherwise infinite loop - return (EnzymeCore.make_zero(prev),)::Tuple{T} + return (EnzymeCore.make_zero(prev)::T,) end function _make_zero!!(val::T, _val::T) where {T} @@ -750,7 +750,7 @@ function _make_zero!!(val::T, _val::T) where {T} @assert isvectortype(T) # otherwise infinite loop @assert val === _val EnzymeCore.make_zero!(val) - return (val,)::Tuple{T} + return (val::T,) end # alternative entry point for passing custom IdDict From a477b51b88007bf0ed3da80985a71566a6a40b93 Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Thu, 9 Jan 2025 21:15:37 -0800 Subject: [PATCH 03/17] Simplify: remove IsInactive, always use *_nongen --- lib/EnzymeCore/src/EnzymeCore.jl | 51 +++----- src/analyses/activity.jl | 4 +- src/typeutils/recursive_add.jl | 46 +++----- src/typeutils/recursive_maps.jl | 194 ++++++------------------------- test/recursive_maps.jl | 128 ++++++++------------ 5 files changed, 120 insertions(+), 303 deletions(-) diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index e59534479e..6829f88fea 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -506,60 +506,35 @@ function autodiff_thunk end function autodiff_deferred_thunk end """ - make_zero( - prev::T, ::Val{copy_if_inactive}=Val(false), ::Val{runtime_inactive}=Val(false) - )::T - make_zero( - ::Type{T}, - seen::IdDict, - prev::T, - ::Val{copy_if_inactive}=Val(false), - ::Val{runtime_inactive}=Val(false), - )::T + make_zero(prev::T, ::Val{copy_if_inactive}=Val(false))::T + make_zero(::Type{T}, seen::IdDict, prev::T, ::Val{copy_if_inactive}=Val(false))::T Recursively make a copy of the value `prev::T` in which all differentiable values are zeroed. The argument `copy_if_inactive` specifies what to do if the type `T` or any of its constituent parts is guaranteed to be inactive (non-differentiable): reuse `prev`s instance (the default) or make a copy. -The argument `runtime_inactive` specifies whether each constituent type is checked for being -guaranteed inactive at runtime for every call to `make_zero`, or if this can be checked once -at compile-time and reused across multiple calls to `make_zero` and related functions (the -default). Runtime checks are necessary to pick up recently added methods to -`EnzymeRules.inactive_type`, but may incur a significant performance penalty and is usually -not needed unless `EnzymeRules.inactive_type` is extended interactively for types that have -previously been passed to `make_zero` or related functions. - Extending this method for custom types is rarely needed. If you implement a new type, such -as a GPU array type, for which `make_zero` should directly invoke `zero` rather than -iterate/broadcast when the eltype is scalar, it is sufficient to implement `Base.zero` and -make sure your type subtypes `DenseArray`. (If subtyping `DenseArray` is not appropriate, -extend [`EnzymeCore.isvectortype`](@ref) directly instead.) +as a GPU array type, on which `make_zero` should directly invoke `zero` when the eltype is +scalar, it is sufficient to implement `Base.zero` and make sure your type subtypes +`DenseArray`. (If subtyping `DenseArray` is not appropriate, extend +[`EnzymeCore.isvectortype`](@ref) instead.) """ function make_zero end """ - make_zero!(val::T, [seen::IdDict], ::Val{runtime_inactive}=Val(false))::Nothing + make_zero!(val::T, [seen::IdDict])::Nothing Recursively set a variable's differentiable values to zero. Only applicable for types `T` that are mutable or hold all differentiable values in mutable storage (e.g., `Tuple{Vector{Float64}}` qualifies but `Tuple{Float64}` does not). The recursion skips over parts of `val` that are guaranteed to be inactive. -The argument `runtime_inactive` specifies whether each constituent type is checked for being -guaranteed inactive at runtime for every call to `make_zero!`, or if this can be checked once -at compile-time and reused across multiple calls to `make_zero!` and related functions (the -default). Runtime checks are necessary to pick up recently added methods to -`EnzymeRules.inactive_type`, but may incur a significant performance penalty and is usually -not needed unless `EnzymeRules.inactive_type` is extended interactively for types that have -previously been passed to `make_zero!` or related functions. - Extending this method for custom types is rarely needed. If you implement a new mutable -type, such as a GPU array type, for which `make_zero!` should directly invoke -`fill!(x, false)` rather than iterate/broadcast when the eltype is scalar, it is sufficient -to implement `Base.zero`, `Base.fill!`, and make sure your type subtypes `DenseArray`. (If -subtyping `DenseArray` is not appropriate, extend [`EnzymeCore.isvectortype`](@ref) directly -instead.) +type, such as a GPU array type, on which `make_zero!` should directly invoke +`fill!(x, false)` when the eltype is scalar, it is sufficient to implement `Base.zero`, +`Base.fill!`, and make sure your type subtypes `DenseArray`. (If subtyping `DenseArray` is +not appropriate, extend [`EnzymeCore.isvectortype`](@ref) instead.) """ function make_zero! end @@ -569,7 +544,7 @@ function make_zero! end Trait defining types whose values should be considered leaf nodes when [`make_zero`](@ref) and [`make_zero!`](@ref) recurse through an object. -By default, `isvectortype(T) == true` for `T` such that `isscalartype(T) == true` or +By default, `isvectortype(T) == true` when `isscalartype(T) == true` or when `T <: DenseArray{U}` where `U` is a bitstype and `isscalartype(U) == true`. A new vector type, such as a GPU array type, should normally subtype `DenseArray` and @@ -607,7 +582,7 @@ in-place mutation through any Julia API; `isscalartype(BigFloat) == true` ensure `make_zero!` will not try to mutate `BigFloat` values.[^BigFloat] By default, `isscalartype(T) == true` and `isscalartype(Complex{T}) == true` for concrete -`T <: AbstractFloat`. +types `T <: AbstractFloat`. A hypothetical new real number type with Enzyme support should usually subtype `AbstractFloat` and inherit the `isscalartype` trait that way. If this is not appropriate, diff --git a/src/analyses/activity.jl b/src/analyses/activity.jl index 2f2b8e6ec8..8fa5331d77 100644 --- a/src/analyses/activity.jl +++ b/src/analyses/activity.jl @@ -414,7 +414,7 @@ Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_c return res end -Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_const_nongen(::Type{T}, world)::Bool where {T} +Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_const_nongen(::Type{T}, world=nothing)::Bool where {T} rt = active_reg_inner(T, (), world) res = rt == AnyState return res @@ -427,7 +427,7 @@ Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_n return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState end -Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_nonactive_nongen(::Type{T}, world)::Bool where {T} +Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_nonactive_nongen(::Type{T}, world=nothing)::Bool where {T} rt = Enzyme.Compiler.active_reg_inner(T, (), world) return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState end diff --git a/src/typeutils/recursive_add.jl b/src/typeutils/recursive_add.jl index 0bf4a16ed0..6de0ca910b 100644 --- a/src/typeutils/recursive_add.jl +++ b/src/typeutils/recursive_add.jl @@ -1,7 +1,7 @@ using .RecursiveMaps: RecursiveMaps, recursive_map, recursive_map! """ - recursive_add(x::T, y::T, f=identity, forcelhs=guaranteed_const) + recursive_add(x::T, y::T, f=identity, forcelhs=guaranteed_const_nongen) !!! warning Internal function, documented for developer convenience but not covered by semver API @@ -18,7 +18,9 @@ The optional argument `forcelhs` takes a callable such that if `forcelhs(S) == t such that `zi = xi + f(yi)` applies to differentiable values, while `zi = xi` applies to non-differentiable values. """ -function recursive_add(x::T, y::T, f::F=identity, forcelhs::L=guaranteed_const) where {T,F,L} +function recursive_add( + x::T, y::T, f::F=identity, forcelhs::L=guaranteed_const_nongen +) where {T,F,L} function addf(xi::S, yi::S) where {S} @assert EnzymeCore.isvectortype(S) return ((xi + f(yi))::S,) @@ -27,8 +29,7 @@ function recursive_add(x::T, y::T, f::F=identity, forcelhs::L=guaranteed_const) end """ - accumulate_seen!(f, seen::IdDict, ::Val{runtime_inactive}=Val(false)) - accumulate_seen!(f, seen::IdDict, isinactivetype::RecursiveMaps.IsInactive) + accumulate_seen!(f, seen::IdDict) !!! warning Internal function, documented for developer convenience but not covered by semver API @@ -42,35 +43,17 @@ vector type instance mappping to another object of the same type and structure. returned value. The recursion stops at instances of types that are themselves cached by `make_zero` -(`recursive_map`), as these objects should have their own entries in `seen`. - -Inactive objects that would be shared/copied rather than zeroed by `make_zero` are skipped. -If the optional `::Val{runtime_inactive}` argument was passed to `make_zero`, the same value -should be passed to `accumulate_seen` for consistency. If needed, a custom -`RecursiveMaps.IsInactive` instance can be provided instead. +(`recursive_map`), as these objects should have their own entries in `seen`. The recursion +also stops at inactive objects that not be zeroed by `make_zero`. """ -function accumulate_seen!( - f::F, seen::IdDict, ::Val{runtime_inactive}=Val(false) -) where {F,runtime_inactive} - accumulate_seen!(f, seen, RecursiveMaps.IsInactive{runtime_inactive}()) - return nothing -end - -function accumulate_seen!( - f::F, seen::IdDict, isinactivetype::RecursiveMaps.IsInactive -) where {F} - isinactivetype_or_seen = RecursiveMaps.IsInactive( - isinactivetype, RecursiveMaps.iscachedtype - ) +function accumulate_seen!(f::F, seen::IdDict) where {F} for (k, v) in seen - _accumulate_seen_item!(f, k, v, isinactivetype, isinactivetype_or_seen) + _accumulate_seen_item!(f, k, v) end return nothing end -function _accumulate_seen_item!( - f::F, k::T, v::T, isinactivetype, isinactivetype_or_seen -) where {F,T} +function _accumulate_seen_item!(f::F, k::T, v::T) where {F,T} function addf!!(ki::S, vi::S) where {S} @assert EnzymeCore.isvectortype(S) return ((ki .+ f.(vi))::S,) @@ -82,10 +65,13 @@ function _accumulate_seen_item!( ki .+= f.(vi) return (ki::S,) end - RecursiveMaps.check_nonactive(T, isinactivetype) - if !isinactivetype(T) + @inline function isinactive_or_cachedtype(::Type{T}) where {T} + return guaranteed_const_nongen(T) || RecursiveMaps.iscachedtype(T) + end + RecursiveMaps.check_nonactive(T) + if !guaranteed_const_nongen(T) newks = RecursiveMaps.recursive_map_inner( - nothing, addf!!, (k,), (k, v), Val(false), isinactivetype_or_seen + nothing, addf!!, (k,), (k, v), Val(false), isinactive_or_cachedtype ) @assert only(newks) === k end diff --git a/src/typeutils/recursive_maps.jl b/src/typeutils/recursive_maps.jl index bad7083b41..04c8998b87 100644 --- a/src/typeutils/recursive_maps.jl +++ b/src/typeutils/recursive_maps.jl @@ -1,86 +1,7 @@ module RecursiveMaps using EnzymeCore: EnzymeCore, isvectortype, isscalartype -using ..Compiler: guaranteed_const, guaranteed_const_nongen, guaranteed_nonactive, - guaranteed_nonactive_nongen - -### IsInactive: helper for creating consistent inactive/nonactive type checkers -""" - isinactivetype = IsInactive{runtime::Bool}(extra=(T -> false)) - isinactivetype = IsInactive(isinactivetype::IsInactive, extra) - -!!! warning - Internal type, documented for developer convenience but not covered by semver API - stability guarantees - -Create a callable `isinactivetype` such that `isinactivetype(T) == true` if the type `T` is -non-differentiable, that is, if differentiable values can never be reached from any instance -of the type (that is, the activity state of `T` is `AnyState`). - -The callable takes an optional argument `Val(nonactive::Bool)`, such that the full signature -is - -```julia -isinactivetype(::Type{T}, ::Val{nonactive}=Val(false))::Bool -``` - -Setting `nonactive == true` selects for _nonactive_ types, which is a superset of inactive -types that also includes types `T` where every differentiable value can be mutated without -creating a new instance of `T` (that is, the activity state of `T` is either `AnyState` or -`DupState`). - -The optional argument `extra` takes a function defining additional types that should be -treated as inactive regardless of their nominal activity state; that is, - -```julia -IsInactive{runtime}(extra)(T, args...) == IsInactive{runtime}()(T, args...) || extra(T) -``` - -The constructor `IsInactive(isinactivetype::IsInactive{runtime}, extra)` can be used to -extend an existing instance `isinactivetype::IsInactive` with an additional `extra` -function, and is more or less equivalent to -`IsInactive{runtime}(T -> isinactivetype.extra(T) || extra(T))`. - -The type parameter `runtime` specifies whether the activity state of a type is queried at -runtime every time the callable is invoked (`true`), or if compile-time queries from earlier -calls can be reused (`false`). Runtime querying is necessary to pick up recently added -methods to `EnzymeRules.inactive_type`, but may incur a significant performance penalty and -is usually not needed unless `EnzymeRules.inactive_type` is extended interactively for types -that have previously been passed to an instance of `IsInactive{false}`. -""" -struct IsInactive{runtime,F} - extra::F - function IsInactive{runtime}( - extra::F=(@nospecialize(T) -> (@inline; false)) - ) where {runtime,F} - return new{runtime::Bool,F}(extra) - end -end - -function IsInactive(isinactivetype::IsInactive{runtime}, extra::F) where {runtime,F} - combinedextra(::Type{T}) where {T} = (isinactivetype.extra(T) || extra(T)) - return IsInactive{runtime}(combinedextra) -end - -@inline function (f::IsInactive{runtime,F})( - ::Type{T}, ::Val{nonactive}=Val(false) -) where {runtime,F,T,nonactive} - if runtime - # evaluate f.extra first, as guaranteed_*_nongen may incur runtime dispatch - if nonactive - return f.extra(T) || guaranteed_nonactive_nongen(T, nothing) - else - return f.extra(T) || guaranteed_const_nongen(T, nothing) - end - else - # evaluate guaranteed_* first, as these are always known at compile time - if nonactive - return guaranteed_nonactive(T) || f.extra(T) - else - return guaranteed_const(T) || f.extra(T) - end - end -end +using ..Compiler: guaranteed_const_nongen, guaranteed_nonactive_nongen ### traits defining active leaf types for recursive_map @inline isdensearraytype(::Type{<:DenseArray}) = true @@ -105,7 +26,7 @@ end ::Val{Nout} xs::NTuple{Nin,T}, ::Val{copy_if_inactive}=Val(false), - isinactivetype=IsInactive{false}(), + isinactivetype=guaranteed_const_nongen, )::T newys = recursive_map( [seen::Union{Nothing,IdDict},] @@ -113,7 +34,7 @@ end ys::NTuple{Nout,T}, xs::NTuple{Nin,T}, ::Val{copy_if_inactive}=Val(false), - isinactivetype=IsInactive{false}(), + isinactivetype=guaranteed_const_nongen, )::T !!! warning @@ -208,10 +129,8 @@ that the type notionally represents. deep-copies such that the object reference graph is reproduced also within the inactive parts.) -* `isinactivetype` (optional): Callable determining which types are considered inactive and - thus treated according to `copy_if_inactive`. The [`IsInactive`](@ref) type is a - convenient helper for obtaining a callable with relevant semantics, but any callable that - maps types to `true` or `false` can be used. +* `isinactivetype` (optional): Callable mapping types to `Bool` to determines whether the + type should be treated according to `copy_if_inactive` (`true`) or recursed into (`false`). """ function recursive_map end @@ -226,7 +145,7 @@ function recursive_map( ys::YS{Nout,T}, xs::NTuple{Nin,T}, copy_if_inactive::Val=Val(false), - isinactivetype::L=IsInactive{false}(), + isinactivetype::L=guaranteed_const_nongen, ) where {F,Nout,Nin,T,L} newys = if isinactivetype(T) recursive_map_inactive(nothing, ys, xs, copy_if_inactive) @@ -245,7 +164,7 @@ function recursive_map( ys::YS{Nout,T}, xs::NTuple{Nin,T}, copy_if_inactive::Val=Val(false), - isinactivetype::L=IsInactive{false}(), + isinactivetype::L=guaranteed_const_nongen, ) where {F,Nout,Nin,T,L} # determine whether to continue recursion, copy/share, or retrieve from cache newys = if isinactivetype(T) @@ -511,12 +430,11 @@ end ### recursive_map!: fully in-place wrapper around recursive_map """ recursive_map!( - [seen::IdDict,] + [seen::Union{Nothing,IdDict},] f!!, ys::NTuple{Nout,T}, xs::NTuple{Nin,T}, - ::Val{copy_if_inactive}=Val(false), - isinactivetype::IsInactive=IsInactive{false}(), + [::Val{copy_if_inactive},] )::Nothing !!! warning @@ -528,38 +446,29 @@ the function `f!!` over every differentiable value encountered and updating `(y1::T, y2::T, ..., yNout::T)`` in-place with the resulting values. This is a simple wrapper that verifies that `T` is a type where all differentiable values -can be updated in-place (this uses the `nonactive == true` mode of `isinactivetype`, see -[`IsInactive`](@ref) for details), calls `recursive_map`, and verifies that the returned -value is indeed identically the same tuple `ys`. See [`recursive_map`](@ref) for details. - -Note that this wrapper only supports instances of [`IsInactive`](@ref) for the -`isinactivetype` argument, as this is the only way we can insure consistency between the -upfront compatibility check and actual behavior. If this is not appropriate, use -`recursive_map` directly. +can be updated in-place, calls `recursive_map`, and verifies that the returned value is +indeed identically the same tuple `ys`. See [`recursive_map`](@ref) for details. """ function recursive_map!( - f!!::F, - ys::NTuple{Nout,T}, - xs::NTuple{Nin,T}, - copy_if_inactive::Val=Val(false), - isinactivetype::IsInactive=IsInactive{false}(), -) where {F,Nout,Nin,T} - check_nonactive(T, isinactivetype) - newys = recursive_map(f!!, ys, xs, copy_if_inactive, isinactivetype) + f!!::F, ys::NTuple{Nout,T}, xs::NTuple{Nin,T}, copy_if_inactives::Vararg{Val,M} +) where {F,Nout,Nin,T,M} + @assert M <= 1 + check_nonactive(T) + newys = recursive_map(f!!, ys, xs, copy_if_inactives...) @assert newys === ys return nothing end function recursive_map!( - seen::IdDict, + seen::Union{Nothing,IdDict}, f!!::F, ys::NTuple{Nout,T}, xs::NTuple{Nin,T}, - copy_if_inactive::Val=Val(false), - isinactivetype::IsInactive=IsInactive{false}(), -) where {F,Nout,Nin,T} - check_nonactive(T, isinactivetype) - newys = recursive_map(seen, f!!, ys, xs, copy_if_inactive, isinactivetype) + copy_if_inactives::Vararg{Val,M}, +) where {F,Nout,Nin,T,M} + @assert M <= 1 + check_nonactive(T) + newys = recursive_map(seen, f!!, ys, xs, copy_if_inactives...) @assert newys === ys return nothing end @@ -670,8 +579,8 @@ Base.@propagate_inbounds check_allinitialized(::Tuple{}, i, initialized=true) = return nothing end -@inline function check_nonactive(::Type{T}, isinactivetype::IsInactive) where {T} - if !isinactivetype(T, Val(true)) #=nonactive=# +@inline function check_nonactive(::Type{T}) where {T} + if !guaranteed_nonactive_nongen(T) throw_nonactive() end return nothing @@ -695,51 +604,29 @@ end end ### EnzymeCore.make_zero(!) implementation -function EnzymeCore.make_zero(prev::T, args::Vararg{Any,M}) where {T,M} - new = if iszero(M) && !IsInactive{false}()(T) && isvectortype(T) # fallback - # IsInactive has precedence over isvectortype for consistency with recursive handler +function EnzymeCore.make_zero(prev::T, copy_if_inactives::Vararg{Val,M}) where {T,M} + @assert M <= 1 + new = if iszero(M) && !guaranteed_const_nongen(T) && isvectortype(T) # fallback + # guaranteed_const has precedence over isvectortype for consistency with recursive_map convert(T, zero(prev)) # convert because zero(prev)::T may fail when eltype(T) is abstract else - _make_zero_inner(prev, args...) + only(recursive_map(_make_zero!!, Val(1), (prev,), copy_if_inactives...)) end return new::T end -function EnzymeCore.make_zero!(val::T, args::Vararg{Any,M}) where {T,M} +function EnzymeCore.make_zero!(val::T, seens::Vararg{IdDict,M}) where {T,M} + @assert M <= 1 @assert !isscalartype(T) # not appropriate for in-place handler - if iszero(M) && !IsInactive{false}()(T) && isvectortype(T) # fallback - # IsInactive has precedence over isvectortype for consistency with recursive handler + if iszero(M) && !guaranteed_const_nongen(T) && isvectortype(T) # fallback + # isinactivetype has precedence over isvectortype for consistency with recursive_map fill!(val, false) else - _make_zero_inner!(val, args...) + recursive_map!(seens..., _make_zero!!, (val,), (val,)) end return nothing end -@inline function _make_zero_inner( - prev::T, copy_if_inactive::Val=Val(false), ::Val{runtime_inactive}=Val(false) -) where {T,runtime_inactive} - isinactivetype = IsInactive{runtime_inactive}() - news = recursive_map(_make_zero!!, Val(1), (prev,), copy_if_inactive, isinactivetype) - return only(news)::T -end - -@inline function _make_zero_inner!( - val::T, ::Val{runtime_inactive}=Val(false) -) where {T,runtime_inactive} - isinactivetype = IsInactive{runtime_inactive}() - recursive_map!(_make_zero!!, (val,), (val,), Val(false), isinactivetype) - return nothing -end - -@inline function _make_zero_inner!( - val::T, seen::IdDict, ::Val{runtime_inactive}=Val(false) -) where {T,runtime_inactive} - isinactivetype = IsInactive{runtime_inactive}() - recursive_map!(seen, _make_zero!!, (val,), (val,), Val(false), isinactivetype) - return nothing -end - function _make_zero!!(prev::T) where {T} @assert isvectortype(T) # otherwise infinite loop return (EnzymeCore.make_zero(prev)::T,) @@ -755,15 +642,10 @@ end # alternative entry point for passing custom IdDict function EnzymeCore.make_zero( - ::Type{T}, - seen::IdDict, - prev::T, - copy_if_inactive::Val=Val(false), - ::Val{runtime_inactive}=Val(false), -) where {T,runtime_inactive} - isinactivetype = IsInactive{runtime_inactive}() - news = recursive_map(seen, _make_zero!!, Val(1), (prev,), copy_if_inactive, isinactivetype) - return only(news)::T + ::Type{T}, seen::IdDict, prev::T, copy_if_inactives::Vararg{Val,M} +) where {T,M} + @assert M <= 1 + return only(recursive_map(seen, _make_zero!!, Val(1), (prev,), copy_if_inactives...))::T end end # module RecursiveMaps diff --git a/test/recursive_maps.jl b/test/recursive_maps.jl index fc6684d1e5..fc903e5139 100644 --- a/test/recursive_maps.jl +++ b/test/recursive_maps.jl @@ -143,7 +143,7 @@ function Base.:(==)(a::MutableIncomplete, b::MutableIncomplete) return true end -mutable struct CustomVector{T} <: AbstractVector{T} +mutable struct CustomVector{T} data::Vector{T} end @@ -551,43 +551,35 @@ function test_make_zero() @test a[1] === 1.0 # no mutation of original end @testset "runtime inactive" begin - a = [1.0] - v = CustomVector(a) - with_logger(SimpleLogger(Warn)) do # silence @info "make_zero(::CustomVector)" - # ensure compile-time methods are evaluated while CustomVector is considered active - @assert !EnzymeRules.inactive_type(CustomVector) - v_makez = make_zero(v, Val(false), Val(false)) - @assert v_makez == CustomVector([0.0]) - - # verify that runtime methods also see CustomVector as active - v_makez = make_zero(v, Val(false), Val(true)) - @test v_makez == CustomVector([0.0]) - - # mark CustomVector as inactive - @eval @inline EnzymeRules.inactive_type(::Type{<:CustomVector}) = true - - # runtime_inactive == false => redefined inactive_type should have no effect - v_makez = @invokelatest make_zero(v, Val(false), Val(false)) - @test v_makez == CustomVector([0.0]) - - # runtime_inactive == true => redefined inactive_type should take effect: - # CustomVector considered inactive and won't be zeroed, but - # shared/copied according to copy_if_inactive instead - v_makez = @invokelatest make_zero(v, Val(false), Val(true)) - @test v_makez === v - v_makez = @invokelatest make_zero(v, Val(true), Val(true)) - @test v_makez !== v - @test v_makez == CustomVector([1.0]) - - # mark CustomVector as active again - @eval @inline EnzymeRules.inactive_type(::Type{<:CustomVector}) = false - - # verify that both compile-time and runtime methods see CustomVector as active - v_makez = @invokelatest make_zero(v, Val(false), Val(false)) - @test v_makez == CustomVector([0.0]) - v_makez = @invokelatest make_zero(v, Val(false), Val(true)) - @test v_makez == CustomVector([0.0]) - end + # verify that MutableWrapper is seen as active + a = MutableWrapper(1.0) + @assert !EnzymeRules.inactive_type(typeof(a)) + a_makez = make_zero(a) + @test a_makez == MutableWrapper(0.0) + + # mark MutableWrapper as inactive + @eval @inline EnzymeRules.inactive_type(::Type{<:MutableWrapper}) = true + + # verify that MutableWrapper is seen as inactive and shared/copied according to + # copy_if_inactive + @assert a.x === 1.0 # sanity check + a_makez = @invokelatest make_zero(a) + @test a_makez == a # equal + @assert a.x === 1.0 # sanity check + a_makez = @invokelatest make_zero(a, Val(false)) + @test a_makez === a # identical + @assert a.x === 1.0 # sanity check + a_makez = @invokelatest make_zero(a, Val(true)) + @test a_makez !== a # not identical + @test a_makez == a # but equal + + # mark MutableWrapper as active again + @eval @inline EnzymeRules.inactive_type(::Type{<:MutableWrapper}) = false + + # verify that MutableWrapper is seen as active + @assert a.x === 1.0 # sanity check + a_makez = @invokelatest make_zero(a) + @test a_makez == MutableWrapper(0.0) end @testset "undefined fields/unassigned elements" begin @testset "array w inactive/active/mutable/unassigned" begin @@ -850,45 +842,27 @@ function test_make_zero!() @test a[1] === 0.0 # correct value end @testset "runtime inactive" begin - a = [1.0] - v = CustomVector(a) - with_logger(SimpleLogger(Warn)) do # silence @info "make_zero!(::CustomVector)" - # ensure compile-time methods are evaluated while CustomVector is considered active - @assert !EnzymeRules.inactive_type(CustomVector) - make_zero!(v, Val(false)) - @assert v == CustomVector([0.0]) - - # verify that runtime methods also see CustomVector as active - v.data[1] = 1.0 - make_zero!(v, Val(true)) - @test v == CustomVector([0.0]) - - # mark CustomVector as inactive - @eval @inline EnzymeRules.inactive_type(::Type{<:CustomVector}) = true - - # runtime_inactive == false => compile-time methods still used, redefined - # inactive_type should have no effect - v.data[1] = 1.0 - @invokelatest make_zero!(v, Val(false)) - @test v == CustomVector([0.0]) - - # runtime_inactive == true => redefined inactive_type should take effect - # CustomVector considered inactive and won't be zeroed - v.data[1] = 1.0 - @invokelatest make_zero!(v, Val(true)) - @test v == CustomVector([1.0]) - - # mark CustomVector as active again - @eval @inline EnzymeRules.inactive_type(::Type{<:CustomVector}) = false - - # verify that both compile-time and runtime methods see CustomVector as active - v.data[1] = 1.0 - @invokelatest make_zero!(v, Val(false)) - @test v == CustomVector([0.0]) - v.data[1] = 1.0 - @invokelatest make_zero!(v, Val(true)) - @test v == CustomVector([0.0]) - end + # verify that MutableWrapper is seen as active + a = MutableWrapper(1.0) + @assert !EnzymeRules.inactive_type(typeof(a)) + make_zero!(a) + @test a == MutableWrapper(0.0) + + # mark MutableWrapper as inactive + @eval @inline EnzymeRules.inactive_type(::Type{<:MutableWrapper}) = true + + # verify that MutableWrapper is seen as inactive + a.x = 1.0 + @invokelatest make_zero!(a) + @test a == MutableWrapper(1.0) + + # mark MutableWrapper as active again + @eval @inline EnzymeRules.inactive_type(::Type{<:MutableWrapper}) = false + + # verify that MutableWrapper is seen as active + a.x = 1.0 + @invokelatest make_zero!(a) + @test a == MutableWrapper(0.0) end @testset "undefined fields/unassigned elements" begin @testset "array w inactive/active/mutable/unassigned" begin From c0fa471ceb5b844f2bfde7cd4c143f80cc0496f0 Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Sun, 12 Jan 2025 20:50:22 -0800 Subject: [PATCH 04/17] Refactor to reduce generated lines --- src/typeutils/recursive_maps.jl | 292 +++++++++++++++++--------------- 1 file changed, 157 insertions(+), 135 deletions(-) diff --git a/src/typeutils/recursive_maps.jl b/src/typeutils/recursive_maps.jl index 04c8998b87..8b680e048b 100644 --- a/src/typeutils/recursive_maps.jl +++ b/src/typeutils/recursive_maps.jl @@ -4,9 +4,6 @@ using EnzymeCore: EnzymeCore, isvectortype, isscalartype using ..Compiler: guaranteed_const_nongen, guaranteed_nonactive_nongen ### traits defining active leaf types for recursive_map -@inline isdensearraytype(::Type{<:DenseArray}) = true -@inline isdensearraytype(::Type) = false - @inline EnzymeCore.isvectortype(::Type{T}) where {T} = isscalartype(T) @inline function EnzymeCore.isvectortype(::Type{<:DenseArray{U}}) where {U} return isbitstype(U) && isscalartype(U) @@ -195,90 +192,116 @@ end return newys::NTuple{Nout,T} end -@generated function recursive_map_mutable( +@inline function recursive_map_mutable( seen, f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, copy_if_inactive, isinactivetype::L ) where {F,Nout,Nin,T,L} @assert ismutabletype(T) - iteration_i = quote - @inbounds if isinitialized(x1, i) - check_allinitialized(xtail, i) - newys_i = recursive_map_index(i, seen, f, ys, xs, copy_if_inactive, isinactivetype) - setitems!(newys, i, newys_i) - elseif hasvalues(ys) - check_allinitialized(ys, i, false) + if !hasvalues(ys) && !(T <: DenseArray) && all(isbitstype, fieldtypes(T)) + # fast path for out-of-place handling when all fields are bitstypes, which rules + # out undefined fields and circular references + newys = recursive_map_new(seen, f, ys, xs, copy_if_inactive, isinactivetype) + maybecache!(seen, newys, xs) + else + newys = if hasvalues(ys) + ys + else + x1 = first(xs) + ntuple(_ -> (@inline; _similar(x1)), Val(Nout)) end + maybecache!(seen, newys, xs) + recursive_map_mutable_inner!(seen, f, newys, ys, xs, copy_if_inactive, isinactivetype) end - return quote - @inline - if !hasvalues(ys) && !isdensearraytype(T) && all(isbitstype, fieldtypes(T)) - # fast path for out-of-place handling when all fields are bitstypes, which rules - # out undefined fields and circular references - newys = recursive_map_new(seen, f, ys, xs, copy_if_inactive, isinactivetype) - maybecache!(seen, newys, xs) - else - x1, xtail = first(xs), Base.tail(xs) - newys = if hasvalues(ys) - ys - else - Base.@ntuple $Nout _ -> _similar(x1) + return newys::NTuple{Nout,T} +end + +@inline function recursive_map_mutable_inner!( + seen, + f::F, + newys::NTuple{Nout,T}, + ys::YS{Nout,T}, + xs::NTuple{Nin,T}, + copy_if_inactive, + isinactivetype::L, +) where {F,Nout,Nin,T<:DenseArray,L} + if (Nout == 1) && isbitstype(eltype(T)) + newy = only(newys) + if hasvalues(ys) + y = only(ys) + broadcast!(newy, y, xs...) do y_i, xs_i... + only(recursive_map(nothing, f, (y_i,), xs_i, copy_if_inactive, isinactivetype)) end - maybecache!(seen, newys, xs) - if isdensearraytype(T) - if (Nout == 1) && isbitstype(eltype(T)) - recursive_map_broadcast!( - f, newys, ys, xs, copy_if_inactive, isinactivetype - ) - else - for i in eachindex(newys..., xs...) - $iteration_i - end - end - else # unrolled loop over struct fields - Base.Cartesian.@nexprs $(fieldcount(T)) i -> $iteration_i + else + broadcast!(newy, xs...) do xs_i... + only(recursive_map(nothing, f, Val(1), xs_i, copy_if_inactive, isinactivetype)) end end - return newys::NTuple{Nout,T} + else + @inbounds for i in eachindex(newys..., xs...) + recursive_map_item!(i, seen, f, newys, ys, xs, copy_if_inactive, isinactivetype) + end + end + return nothing +end + +@generated function recursive_map_mutable_inner!( + seen, + f::F, + newys::NTuple{Nout,T}, + ys::YS{Nout,T}, + xs::NTuple{Nin,T}, + copy_if_inactive, + isinactivetype::L, +) where {F,Nout,Nin,T,L} + return quote + @inline + Base.Cartesian.@nexprs $(fieldcount(T)) i -> @inbounds begin + recursive_map_item!(i, seen, f, newys, ys, xs, copy_if_inactive, isinactivetype) + end + return nothing end end -@generated function recursive_map_immutable( +@inline function recursive_map_immutable( seen, f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, copy_if_inactive, isinactivetype::L ) where {F,Nout,Nin,T,L} @assert !ismutabletype(T) + nf = fieldcount(T) + if nf == 0 # nothing to do (also no known way to hit this branch) + newys = recursive_map_inactive(seen, ys, xs, Val(false)) + else + newys = if isinitialized(first(xs), nf) # fast path when all fields are defined + check_allinitialized(Base.tail(xs), nf) + recursive_map_new(seen, f, ys, xs, copy_if_inactive, isinactivetype) + else + recursive_map_immutable_inner(seen, f, ys, xs, copy_if_inactive, isinactivetype) + end + # maybecache! _should_ be a no-op here; call it anyway for consistency + maybecache!(seen, newys, xs) + end + return newys::NTuple{Nout,T} +end + +@generated function recursive_map_immutable_inner( + seen, f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, copy_if_inactive, isinactivetype::L +) where {F,Nout,Nin,T,L} nf = fieldcount(T) return quote @inline - if $nf == 0 # nothing to do (also no known way to hit this branch) - newys = recursive_map_inactive(nothing, ys, xs, Val(false)) - else - x1, xtail = first(xs), Base.tail(xs) - if isinitialized(x1, $nf) # fast path when all fields are defined - check_allinitialized(xtail, $nf) - newys = recursive_map_new(seen, f, ys, xs, copy_if_inactive, isinactivetype) + x1, xtail = first(xs), Base.tail(xs) + fields = Base.@ntuple $Nout _ -> Vector{Any}(undef, $(nf - 1)) + Base.Cartesian.@nexprs $(nf - 1) i -> begin # unrolled loop over struct fields + @inbounds if isinitialized(x1, i) + check_allinitialized(xtail, i) + newys_i = recursive_map_item( + i, seen, f, ys, xs, copy_if_inactive, isinactivetype + ) + Base.Cartesian.@nexprs $Nout j -> (fields[j][i] = newys_i[j]) else - Base.Cartesian.@nexprs $Nout j -> (fields_j = Vector{Any}(undef, $(nf - 1))) - Base.Cartesian.@nexprs $(nf - 1) i -> begin # unrolled loop over struct fields - @inbounds if isinitialized(x1, i) - check_allinitialized(xtail, i) - newys_i = recursive_map_index( - i, seen, f, ys, xs, copy_if_inactive, isinactivetype - ) - Base.Cartesian.@nexprs $Nout j -> (fields_j[i] = newys_i[j]) - else - ndef = i - 1 # rest of tail must be undefined values - @goto done # break out of unrolled loop - end - end - ndef = $(nf - 1) # loop didn't break, only last field is undefined - @label done - newys = Base.@ntuple $Nout j -> begin - ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), T, fields_j, ndef)::T - end + return new_structvs(T, fields, i - 1) end - # maybecache! _should_ be a no-op here; call it anyway for consistency - maybecache!(seen, newys, xs) end - return newys::NTuple{Nout,T} + @assert !isinitialized(x1, $nf) + return new_structvs(T, fields, $(nf - 1)) end end @@ -289,10 +312,8 @@ end nf = fieldcount(T) return quote @inline - Base.Cartesian.@nexprs $nf i -> begin - newys_i = @inbounds recursive_map_index( - i, seen, f, ys, xs, copy_if_inactive, isinactivetype - ) + Base.Cartesian.@nexprs $nf i -> @inbounds begin + newys_i = recursive_map_item(i, seen, f, ys, xs, copy_if_inactive, isinactivetype) end newys = Base.@ntuple $Nout j -> begin $(Expr(:splatnew, :T, :(Base.@ntuple $nf i -> newys_i[j]))) @@ -301,36 +322,27 @@ end end end -@inline function recursive_map_broadcast!( - f::F, newys::NTuple{1,T}, ys::YS{1,T}, xs::NTuple{Nin,T}, copy_if_inactive, isinactivetype::L -) where {F,Nin,T,L} - # broadcast recursive_map over array-like inputs with isbits elements - @assert isdensearraytype(T) - @assert isbitstype(eltype(T)) - newy = first(newys) - if hasvalues(ys) - @assert newys === ys - broadcast!( - (newy_i, xs_i...) -> first(recursive_map_barrier!!( - nothing, f, copy_if_inactive, isinactivetype, Val(1), newy_i, xs_i... - )), - newy, - newy, - xs..., - ) - else - broadcast!( - (xs_i...,) -> first(recursive_map_barrier( - nothing, f, copy_if_inactive, isinactivetype, Val(1), xs_i... - )), - newy, - xs..., - ) +Base.@propagate_inbounds function recursive_map_item!( + i, + seen, + f::F, + newys::NTuple{Nout,T}, + ys::YS{Nout,T}, + xs::NTuple{Nin,T}, + copy_if_inactive, + isinactivetype::L, +) where {F,Nout,Nin,T,L} + if isinitialized(first(xs), i) + check_allinitialized(Base.tail(xs), i) + newys_i = recursive_map_item(i, seen, f, ys, xs, copy_if_inactive, isinactivetype) + setitems!(newys, i, newys_i) + elseif hasvalues(ys) + check_allinitialized(ys, i, false) end return nothing end -Base.@propagate_inbounds function recursive_map_index( +Base.@propagate_inbounds function recursive_map_item( i, seen, f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, copy_if_inactive, isinactivetype::L ) where {F,Nout,Nin,T,L} # recurse into the xs and apply recursive_map to items with index i @@ -367,10 +379,10 @@ end # specialized methods to optimize the common cases Nout == 1 and Nout == 2 function recursive_map_barrier!!( - seen, f::F, copy_if_inactive::Val, isinactivetype::L, ::Val{1}, yi::ST, xs_i::Vararg{ST,Nin} + seen, f::F, copy_if_inactive::Val, isinactivetype::L, ::Val{1}, y_i::ST, xs_i::Vararg{ST,Nin} ) where {F,Nin,ST,L} return recursive_map( - seen, f, (yi,), xs_i, copy_if_inactive, isinactivetype + seen, f, (y_i,), xs_i, copy_if_inactive, isinactivetype )::NTuple{1,ST} end @@ -394,14 +406,13 @@ end seen, f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T} ) where {F,Nout,Nin,T} # apply the mapped function to leaf values - newys = if !hasvalues(ys) || isbitstype(T) || isscalartype(T) - f(xs...)::NTuple{Nout,T} + if !hasvalues(ys) || isbitstype(T) || isscalartype(T) + newys = f(xs...)::NTuple{Nout,T} else # !isbitstype(T) - newys_ = f(ys..., xs...)::NTuple{Nout,T} + newys = f(ys..., xs...)::NTuple{Nout,T} if ismutabletype(T) - @assert newys_ === ys + @assert newys === ys end - newys_ end maybecache!(seen, newys, xs) return newys::NTuple{Nout,T} @@ -413,18 +424,20 @@ end return ys::NTuple{Nout,T} end -@generated function recursive_map_inactive( - seen, ::Val{Nout}, xs::NTuple{Nin,T}, ::Val{copy_if_inactive} +@inline function recursive_map_inactive( + seen, ::Val{Nout}, (x1,)::NTuple{Nin,T}, ::Val{copy_if_inactive} ) where {Nout,Nin,T,copy_if_inactive} - return quote - @inline - y = if copy_if_inactive && !isbitstype(T) - Base.deepcopy_internal(first(xs), isnothing(seen) ? IdDict() : seen) + @inline + y = if copy_if_inactive && !isbitstype(T) + if isnothing(seen) + deepcopy(x1) else - first(xs) + Base.deepcopy_internal(x1, seen) end - return (Base.@ntuple $Nout _ -> y)::NTuple{Nout,T} + else + x1 end + return ntuple(_ -> (@inline; y), Val(Nout))::NTuple{Nout,T} end ### recursive_map!: fully in-place wrapper around recursive_map @@ -474,6 +487,15 @@ function recursive_map!( end ### recursive_map helpers +@generated function new_structvs(::Type{T}, fields::NTuple{N,Vector{Any}}, nfields_) where {T,N} + return quote + @inline + return Base.@ntuple $N j -> begin + ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), T, fields[j], nfields_)::T + end + end +end + @inline _similar(::T) where {T} = ccall(:jl_new_struct_uninit, Any, (Any,), T)::T @inline _similar(x::T) where {T<:DenseArray} = similar(x)::T Base.@propagate_inbounds isinitialized(x, i) = isdefined(x, i) @@ -492,22 +514,22 @@ Base.@propagate_inbounds function setfield_force!(x::T, i, v) where {T} return nothing end -Base.@propagate_inbounds function getitems(xs::Tuple{T,T,Vararg{T,N}}, i) where {T,N} - return (getitem(first(xs), i), getitems(Base.tail(xs), i)...) +Base.@propagate_inbounds function getitems((x1, xtail...)::Tuple{T,T,Vararg{T,N}}, i) where {T,N} + return (getitem(x1, i), getitems(xtail, i)...) end -Base.@propagate_inbounds getitems(xs::Tuple{T}, i) where {T} = (getitem(only(xs), i),) +Base.@propagate_inbounds getitems((x1,)::Tuple{T}, i) where {T} = (getitem(x1, i),) Base.@propagate_inbounds function setitems!( # may not show in coverage but is covered via accumulate_into! TODO: ensure coverage via VectorSpace once implemented - xs::Tuple{T,T,Vararg{T,N}}, i, vs::Tuple{ST,ST,Vararg{ST,N}} + (x1, xtail...)::Tuple{T,T,Vararg{T,N}}, i, (v1, vtail...)::Tuple{ST,ST,Vararg{ST,N}} ) where {T,ST,N} - setitem!(first(xs), i, first(vs)) - setitems!(Base.tail(xs), i, Base.tail(vs)) + setitem!(x1, i, v1) + setitems!(xtail, i, vtail) return nothing end -Base.@propagate_inbounds function setitems!(xs::Tuple{T}, i, vs::Tuple{ST}) where {T,ST} - setitem!(only(xs), i, only(vs)) +Base.@propagate_inbounds function setitems!((x1,)::Tuple{T}, i, (v1,)::Tuple{ST}) where {T,ST} + setitem!(x1, i, v1) return nothing end @@ -520,28 +542,28 @@ end @inline shouldcache(::IdDict, ::Type{T}) where {T} = iscachedtype(T) @inline shouldcache(::Nothing, ::Type{T}) where {T} = false -@inline function maybecache!(seen, newys::NTuple{Nout,T}, xs::NTuple{Nin,T}) where {Nout,Nin,T} +@inline function maybecache!(seen, newys::NTuple{Nout,T}, (x1, xtail...)::NTuple{Nin,T}) where {Nout,Nin,T} if shouldcache(seen, T) - if (Nout == 1) && (Nin == 1) - seen[only(xs)] = only(newys) + seen[x1] = if (Nout == 1) && (Nin == 1) + only(newys) else # may not show in coverage but is covered via accumulate_into! TODO: ensure coverage via VectorSpace once implemented - seen[first(xs)] = (newys..., Base.tail(xs)...) + (newys..., xtail...) end end return nothing end -@inline function hascache(seen, xs::NTuple{Nin,T}) where {Nin,T} - return shouldcache(seen, T) ? haskey(seen, first(xs)) : false +@inline function hascache(seen, (x1,)::NTuple{Nin,T}) where {Nin,T} + return shouldcache(seen, T) ? haskey(seen, x1) : false end -@inline function getcached(seen::IdDict, ::Val{Nout}, xs::NTuple{Nin,T}) where {Nout,Nin,T} +@inline function getcached(seen::IdDict, ::Val{Nout}, (x1, xtail...)::NTuple{Nin,T}) where {Nout,Nin,T} newys = if (Nout == 1) && (Nin == 1) - (seen[only(xs)]::T,) + (seen[x1]::T,) else # may not show in coverage but is covered via accumulate_into! TODO: ensure coverage via VectorSpace once implemented - cache = seen[first(xs)]::NTuple{(Nout + Nin - 1),T} + cache = seen[x1]::NTuple{(Nout + Nin - 1),T} cachedtail = cache[(Nout+1):end] - check_identical(cachedtail, Base.tail(xs)) # check compatible layout + check_identical(cachedtail, xtail) # check compatible layout cache[1:Nout] end return newys::NTuple{Nout,T} @@ -556,17 +578,17 @@ Base.@propagate_inbounds function check_initialized(x, i, initialized=true) end Base.@propagate_inbounds function check_allinitialized( # TODO: hit this when VectorSpace implemented - xs::Tuple{T,T,Vararg{T,N}}, i, initialized=true + (x1, xtail...)::Tuple{T,T,Vararg{T,N}}, i, initialized=true ) where {T,N} - check_initialized(first(xs), i, initialized) - check_allinitialized(Base.tail(xs), i, initialized) + check_initialized(x1, i, initialized) + check_allinitialized(xtail, i, initialized) return nothing end Base.@propagate_inbounds function check_allinitialized( - xs::Tuple{T}, i, initialized=true + (x1,)::Tuple{T}, i, initialized=true ) where {T} - check_initialized(only(xs), i, initialized) + check_initialized(x1, i, initialized) return nothing end From f852e50202397e12e160851e8096e97cf788ca51 Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Mon, 13 Jan 2025 09:33:15 -0800 Subject: [PATCH 05/17] Only support Nout in (1, 2) --- src/typeutils/recursive_maps.jl | 84 ++++++++++++++++----------------- 1 file changed, 41 insertions(+), 43 deletions(-) diff --git a/src/typeutils/recursive_maps.jl b/src/typeutils/recursive_maps.jl index 8b680e048b..418fb6ea8f 100644 --- a/src/typeutils/recursive_maps.jl +++ b/src/typeutils/recursive_maps.jl @@ -40,20 +40,20 @@ end Recurse through `Nin` objects `xs = (x1::T, x2::T, ..., xNin::T)` of the same type, mapping the function `f` over every differentiable value encountered and building `Nout` new objects -`(y1::T, y2::T, ..., yNout::T)` from the resulting values -`(y1_i, ..., yNout_i) = f(x1_i, ..., xNin_i)`. +`(y1::T, ...)` from the resulting values `(y1_i, ...) = f(x1_i, ..., xNin_i)`. Only +`Nout == 1` and `Nout == 2` are supported. The trait `EnzymeCore.isvectortype`(@ref) determines which values are considered differentiable leaf nodes at which recursion terminates and `f` is invoked. See the docstring for [`EnzymeCore.isvectortype`](@ref) and the related [`EnzymeCore.isscalartype`](@ref) for more information. -A tuple of existing objects `ys = (y1::T, ..., yNout::T)` can be passed, in which case the -`ys` are updated "partially-in-place": any parts of the `ys` that are mutable or -non-differentiable are reused in the returned object tuple `newys`, while immutable -differentiable parts are handled out-of-place as if the `ys` were not passed (this can be -seen as a recursive generalization of the BangBang.jl idiom). If `T` itself is a mutable -type, the `ys` are modified in-place and returned, such that `newys === ys`. +A tuple of existing objects `ys = (y1::T, ...)` can be passed, in which case the `ys` are +updated "partially-in-place": any parts of the `ys` that are mutable or non-differentiable +are reused in the returned object tuple `newys`, while immutable differentiable parts are +handled out-of-place as if the `ys` were not passed (this can be seen as a recursive +generalization of the BangBang.jl idiom). If `T` itself is a mutable type, the `ys` are +modified in-place and returned, such that `newys === ys`. The recursion and mapping operates on the structure of `T` as defined by struct fields and plain array elements, not on the values provided through an iteration or array interface. @@ -74,13 +74,13 @@ that the type notionally represents. will likely cause errors. This is useful only in specific cases. * `f`: Function mapping leaf nodes within the `xs` to the corresponding leaf nodes in the - `ys`, that is, `(y1_i, ..., yNout_i) = f(x1_i::U, ..., xNin_i::U)::NTuple{Nout,U}`. - The function `f` must be applicable to the type of every leaf node, and must return a - tuple of values of the same type as its arguments. + `ys`, that is, `(y1_i, ...) = f(x1_i::U, ..., xNin_i::U)::NTuple{Nout,U}`. The function + `f` must be applicable to the type of every leaf node, and must return a tuple of values + of the same type as its arguments. When an existing object tuple `ys` is passed and contains leaf nodes of a non-isbits non-scalar type `U`, `f` should also have a partially-in-place method - `(newy1_i, ..., newyNout_i) === f(y1_i::U, ..., yNout_i::U, x1_i::U, ..., xNin_i::U)::NTuple{Nout,U}` + `(newy1_i, ...) === f(y1_i::U, ..., yNout_i::U, x1_i::U, ..., xNin_i::U)::NTuple{Nout,U}` that modifies and reuses any mutable parts of the `yj_i`; in particular, if `U` is a mutable type, this method should return `newyj_i === yj_i`. If a non-isbits type `U` should always be handled using the out-of-place signature, extend @@ -90,10 +90,10 @@ that the type notionally represents. details about leaf types and scalar types. * `::Val{Nout}` or `ys::NTuple{Nout,T}`: For out-of-place operation, pass `Val(Nout)` where - `Nout` is the length of the tuple returned by `f`, that is, the length of the expected - return value `ys` (this is required; `Nout` never inferred). For partially-in-place - operation, pass the existing tuple `ys::NTuple{Nout,T}` containing the values to be - modified. + `Nout in (1, 2)` is the length of the tuple returned by `f`, that is, the length of the + expected return value `ys` (this is required; `Nout` never inferred). For + partially-in-place operation, pass the existing tuple `ys::NTuple{Nout,T}` containing the + values to be modified. * `xs::NTuple{N,T}`: Tuple of `N` objects of the same type `T` over which `f` is mapped. @@ -144,6 +144,7 @@ function recursive_map( copy_if_inactive::Val=Val(false), isinactivetype::L=guaranteed_const_nongen, ) where {F,Nout,Nin,T,L} + check_nout(ys) newys = if isinactivetype(T) recursive_map_inactive(nothing, ys, xs, copy_if_inactive) elseif isvectortype(T) || isbitstype(T) @@ -164,6 +165,7 @@ function recursive_map( isinactivetype::L=guaranteed_const_nongen, ) where {F,Nout,Nin,T,L} # determine whether to continue recursion, copy/share, or retrieve from cache + check_nout(ys) newys = if isinactivetype(T) recursive_map_inactive(seen, ys, xs, copy_if_inactive) elseif isbitstype(T) # no object identity to to track in this branch @@ -350,55 +352,39 @@ Base.@propagate_inbounds function recursive_map_item( newys_i = if hasvalues(ys) && isinitialized(first(ys), i) check_allinitialized(Base.tail(ys), i) ys_i = getitems(ys, i) - recursive_map_barrier!!( - seen, f, copy_if_inactive, isinactivetype, Val(Nout), ys_i..., xs_i... - ) + recursive_map_barrier!!(seen, f, ys_i..., copy_if_inactive, isinactivetype, xs_i...) else - recursive_map_barrier(seen, f, copy_if_inactive, isinactivetype, Val(Nout), xs_i...) + recursive_map_barrier(seen, f, Val(Nout), copy_if_inactive, isinactivetype, xs_i...) end return newys_i end # function barriers such that abstractly typed items trigger minimal runtime dispatch function recursive_map_barrier( - seen, f::F, copy_if_inactive::Val, isinactivetype::L, ::Val{Nout}, xs_i::Vararg{ST,Nin} + seen, f::F, ::Val{Nout}, copy_if_inactive::Val, isinactivetype::L, xs_i::Vararg{ST,Nin} ) where {F,Nout,Nin,ST,L} return recursive_map( seen, f, Val(Nout), xs_i, copy_if_inactive, isinactivetype )::NTuple{Nout,ST} end -function recursive_map_barrier!!( # TODO: hit this when VectorSpace implemented - seen, f::F, copy_if_inactive, isinactivetype::L, ::Val{Nout}, yxs_i::Vararg{ST,M} -) where {F,Nout,M,ST,L} - ys_i, xs_i = yxs_i[1:(Nout::Int)], yxs_i[((Nout::Int)+1):end] - return recursive_map( - seen, f, ys_i, xs_i, copy_if_inactive, isinactivetype - )::NTuple{Nout,ST} -end - -# specialized methods to optimize the common cases Nout == 1 and Nout == 2 function recursive_map_barrier!!( - seen, f::F, copy_if_inactive::Val, isinactivetype::L, ::Val{1}, y_i::ST, xs_i::Vararg{ST,Nin} + seen, f::F, y_i::ST, copy_if_inactive::Val, isinactivetype::L, xs_i::Vararg{ST,Nin} ) where {F,Nin,ST,L} - return recursive_map( - seen, f, (y_i,), xs_i, copy_if_inactive, isinactivetype - )::NTuple{1,ST} + return recursive_map(seen, f, (y_i,), xs_i, copy_if_inactive, isinactivetype)::NTuple{1,ST} end function recursive_map_barrier!!( # may not show in coverage but is covered via accumulate_into! TODO: ensure coverage via VectorSpace once implemented seen, f::F, - copy_if_inactive::Val, - isinactivetype::L, - ::Val{2}, y1_i::ST, y2_i::ST, + copy_if_inactive::Val, + isinactivetype::L, xs_i::Vararg{ST,Nin} ) where {F,Nin,ST,L} - return recursive_map( - seen, f, (y1_i, y2_i), xs_i, copy_if_inactive, isinactivetype - )::NTuple{2,ST} + ys_i = (y1_i, y2_i) + return recursive_map(seen, f, ys_i, xs_i, copy_if_inactive, isinactivetype)::NTuple{2,ST} end ## recursion base case handlers @@ -455,13 +441,15 @@ end stability guarantees Recurse through `Nin` objects `xs = (x1::T, x2::T, ..., xNin::T)` of the same type, mapping -the function `f!!` over every differentiable value encountered and updating -`(y1::T, y2::T, ..., yNout::T)`` in-place with the resulting values. +the function `f!!` over every differentiable value encountered and updating `(y1::T, ...)` +in-place with the resulting values. This is a simple wrapper that verifies that `T` is a type where all differentiable values can be updated in-place, calls `recursive_map`, and verifies that the returned value is indeed identically the same tuple `ys`. See [`recursive_map`](@ref) for details. """ +function recursive_map! end + function recursive_map!( f!!::F, ys::NTuple{Nout,T}, xs::NTuple{Nin,T}, copy_if_inactives::Vararg{Val,M} ) where {F,Nout,Nin,T,M} @@ -570,6 +558,12 @@ end end ## argument validation +@inline function check_nout(::YS{Nout}) where {Nout} + if Nout > 2 + throw_nout() + end +end + Base.@propagate_inbounds function check_initialized(x, i, initialized=true) if isinitialized(x, i) != initialized throw_initialized() # TODO: hit this when VectorSpace implemented @@ -609,6 +603,10 @@ end end # TODO: hit all of these via check_* when VectorSpace implemented +@noinline function throw_nout() + throw(ArgumentError("recursive_map(!) only supports mapping to 1 or 2 outputs")) +end + @noinline function throw_initialized() msg = "recursive_map(!) called on objects whose undefined fields/unassigned elements " msg *= "don't line up" From d24baec8156afc5230c64979f53a6e5bbf9aa7f2 Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Sun, 19 Jan 2025 15:11:56 -0800 Subject: [PATCH 06/17] Revert "Simplify: remove IsInactive, always use *_nongen" With some tweaks to keep good ideas from the commits since then --- lib/EnzymeCore/src/EnzymeCore.jl | 30 ++++- src/analyses/activity.jl | 4 +- src/typeutils/recursive_add.jl | 42 ++++--- src/typeutils/recursive_maps.jl | 183 +++++++++++++++++++++++++------ test/recursive_maps.jl | 64 ++++++----- 5 files changed, 246 insertions(+), 77 deletions(-) diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index 6829f88fea..99694c02bd 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -506,14 +506,30 @@ function autodiff_thunk end function autodiff_deferred_thunk end """ - make_zero(prev::T, ::Val{copy_if_inactive}=Val(false))::T - make_zero(::Type{T}, seen::IdDict, prev::T, ::Val{copy_if_inactive}=Val(false))::T + make_zero( + prev::T, ::Val{copy_if_inactive}=Val(false), ::Val{runtime_inactive}=Val(false) + )::T + make_zero( + ::Type{T}, + seen::IdDict, + prev::T, + ::Val{copy_if_inactive}=Val(false), + ::Val{runtime_inactive}=Val(false), + )::T Recursively make a copy of the value `prev::T` in which all differentiable values are zeroed. The argument `copy_if_inactive` specifies what to do if the type `T` or any of its constituent parts is guaranteed to be inactive (non-differentiable): reuse `prev`s instance (the default) or make a copy. +The argument `runtime_inactive` specifies whether each constituent type is checked for being +guaranteed inactive at runtime for every call to `make_zero`, or if this can be checked once +at compile-time and reused across multiple calls to `make_zero` and related functions (the +default). Runtime checks are necessary to pick up recently added methods to +`EnzymeRules.inactive_type`, but may incur a significant performance penalty and is usually +not needed unless `EnzymeRules.inactive_type` is extended interactively for types that have +previously been passed to `make_zero` or related functions. + Extending this method for custom types is rarely needed. If you implement a new type, such as a GPU array type, on which `make_zero` should directly invoke `zero` when the eltype is scalar, it is sufficient to implement `Base.zero` and make sure your type subtypes @@ -523,13 +539,21 @@ scalar, it is sufficient to implement `Base.zero` and make sure your type subtyp function make_zero end """ - make_zero!(val::T, [seen::IdDict])::Nothing + make_zero!(val::T, [seen::IdDict], ::Val{runtime_inactive}=Val(false))::Nothing Recursively set a variable's differentiable values to zero. Only applicable for types `T` that are mutable or hold all differentiable values in mutable storage (e.g., `Tuple{Vector{Float64}}` qualifies but `Tuple{Float64}` does not). The recursion skips over parts of `val` that are guaranteed to be inactive. +The argument `runtime_inactive` specifies whether each constituent type is checked for being +guaranteed inactive at runtime for every call to `make_zero!`, or if this can be checked once +at compile-time and reused across multiple calls to `make_zero!` and related functions (the +default). Runtime checks are necessary to pick up recently added methods to +`EnzymeRules.inactive_type`, but may incur a significant performance penalty and is usually +not needed unless `EnzymeRules.inactive_type` is extended interactively for types that have +previously been passed to `make_zero!` or related functions. + Extending this method for custom types is rarely needed. If you implement a new mutable type, such as a GPU array type, on which `make_zero!` should directly invoke `fill!(x, false)` when the eltype is scalar, it is sufficient to implement `Base.zero`, diff --git a/src/analyses/activity.jl b/src/analyses/activity.jl index 8fa5331d77..2f2b8e6ec8 100644 --- a/src/analyses/activity.jl +++ b/src/analyses/activity.jl @@ -414,7 +414,7 @@ Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_c return res end -Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_const_nongen(::Type{T}, world=nothing)::Bool where {T} +Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_const_nongen(::Type{T}, world)::Bool where {T} rt = active_reg_inner(T, (), world) res = rt == AnyState return res @@ -427,7 +427,7 @@ Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_n return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState end -Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_nonactive_nongen(::Type{T}, world=nothing)::Bool where {T} +Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_nonactive_nongen(::Type{T}, world)::Bool where {T} rt = Enzyme.Compiler.active_reg_inner(T, (), world) return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState end diff --git a/src/typeutils/recursive_add.jl b/src/typeutils/recursive_add.jl index 6de0ca910b..cb421728eb 100644 --- a/src/typeutils/recursive_add.jl +++ b/src/typeutils/recursive_add.jl @@ -1,7 +1,7 @@ using .RecursiveMaps: RecursiveMaps, recursive_map, recursive_map! """ - recursive_add(x::T, y::T, f=identity, forcelhs=guaranteed_const_nongen) + recursive_add(x::T, y::T, f=identity, forcelhs=guaranteed_const) !!! warning Internal function, documented for developer convenience but not covered by semver API @@ -18,9 +18,7 @@ The optional argument `forcelhs` takes a callable such that if `forcelhs(S) == t such that `zi = xi + f(yi)` applies to differentiable values, while `zi = xi` applies to non-differentiable values. """ -function recursive_add( - x::T, y::T, f::F=identity, forcelhs::L=guaranteed_const_nongen -) where {T,F,L} +function recursive_add(x::T, y::T, f::F=identity, forcelhs::L=guaranteed_const) where {T,F,L} function addf(xi::S, yi::S) where {S} @assert EnzymeCore.isvectortype(S) return ((xi + f(yi))::S,) @@ -29,7 +27,8 @@ function recursive_add( end """ - accumulate_seen!(f, seen::IdDict) + accumulate_seen!(f, seen::IdDict, ::Val{runtime_inactive}=Val(false)) + accumulate_seen!(f, seen::IdDict, isinactivetype::RecursiveMaps.IsInactive) !!! warning Internal function, documented for developer convenience but not covered by semver API @@ -44,16 +43,34 @@ returned value. The recursion stops at instances of types that are themselves cached by `make_zero` (`recursive_map`), as these objects should have their own entries in `seen`. The recursion -also stops at inactive objects that not be zeroed by `make_zero`. +also stops at inactive objects that would not be zeroed by `make_zero`. + +If the optional `::Val{runtime_inactive}` argument was passed to `make_zero`, the same value +should be passed to `accumulate_seen` for consistency. If needed, a custom +`RecursiveMaps.IsInactive` instance can be provided instead. """ -function accumulate_seen!(f::F, seen::IdDict) where {F} +function accumulate_seen!( + f::F, seen::IdDict, ::Val{runtime_inactive}=Val(false) +) where {F,runtime_inactive} + accumulate_seen!(f, seen, RecursiveMaps.IsInactive{runtime_inactive}()) + return nothing +end + +function accumulate_seen!( + f::F, seen::IdDict, isinactivetype::RecursiveMaps.IsInactive +) where {F} + isinactive_or_cachedtype = RecursiveMaps.IsInactive( + isinactivetype, RecursiveMaps.iscachedtype + ) for (k, v) in seen - _accumulate_seen_item!(f, k, v) + _accumulate_seen_item!(f, k, v, isinactivetype, isinactive_or_cachedtype) end return nothing end -function _accumulate_seen_item!(f::F, k::T, v::T) where {F,T} +function _accumulate_seen_item!( + f::F, k::T, v::T, isinactivetype, isinactive_or_cachedtype +) where {F,T} function addf!!(ki::S, vi::S) where {S} @assert EnzymeCore.isvectortype(S) return ((ki .+ f.(vi))::S,) @@ -65,11 +82,8 @@ function _accumulate_seen_item!(f::F, k::T, v::T) where {F,T} ki .+= f.(vi) return (ki::S,) end - @inline function isinactive_or_cachedtype(::Type{T}) where {T} - return guaranteed_const_nongen(T) || RecursiveMaps.iscachedtype(T) - end - RecursiveMaps.check_nonactive(T) - if !guaranteed_const_nongen(T) + RecursiveMaps.check_nonactive(T, isinactivetype) + if !isinactivetype(T) newks = RecursiveMaps.recursive_map_inner( nothing, addf!!, (k,), (k, v), Val(false), isinactive_or_cachedtype ) diff --git a/src/typeutils/recursive_maps.jl b/src/typeutils/recursive_maps.jl index 418fb6ea8f..49acc04735 100644 --- a/src/typeutils/recursive_maps.jl +++ b/src/typeutils/recursive_maps.jl @@ -1,7 +1,86 @@ module RecursiveMaps using EnzymeCore: EnzymeCore, isvectortype, isscalartype -using ..Compiler: guaranteed_const_nongen, guaranteed_nonactive_nongen +using ..Compiler: guaranteed_const, guaranteed_const_nongen, guaranteed_nonactive, + guaranteed_nonactive_nongen + +### IsInactive: helper for creating consistent inactive/nonactive type checkers +""" + isinactivetype = IsInactive{runtime::Bool}(extra=(T -> false)) + isinactivetype = IsInactive(isinactivetype::IsInactive, extra) + +!!! warning + Internal type, documented for developer convenience but not covered by semver API + stability guarantees + +Create a callable `isinactivetype` such that `isinactivetype(T) == true` if the type `T` is +non-differentiable, that is, if differentiable values can never be reached from any instance +of the type (that is, the activity state of `T` is `AnyState`). + +The callable takes an optional argument `Val(nonactive::Bool)`, such that the full signature +is + +```julia +isinactivetype(::Type{T}, ::Val{nonactive}=Val(false))::Bool +``` + +Setting `nonactive == true` selects for _nonactive_ types, which is a superset of inactive +types that also includes types `T` where every differentiable value can be mutated without +creating a new instance of `T` (that is, the activity state of `T` is either `AnyState` or +`DupState`). + +The optional argument `extra` takes a function defining additional types that should be +treated as inactive regardless of their nominal activity state; that is, + +```julia +IsInactive{runtime}(extra)(T, args...) == IsInactive{runtime}()(T, args...) || extra(T) +``` + +The constructor `IsInactive(isinactivetype::IsInactive{runtime}, extra)` can be used to +extend an existing instance `isinactivetype::IsInactive` with an additional `extra` +function, and is more or less equivalent to +`IsInactive{runtime}(T -> isinactivetype.extra(T) || extra(T))`. + +The type parameter `runtime` specifies whether the activity state of a type is queried at +runtime every time the callable is invoked (`true`), or if compile-time queries from earlier +calls can be reused (`false`). Runtime querying is necessary to pick up recently added +methods to `EnzymeRules.inactive_type`, but may incur a significant performance penalty and +is usually not needed unless `EnzymeRules.inactive_type` is extended interactively for types +that have previously been passed to an instance of `IsInactive{false}`. +""" +struct IsInactive{runtime,F} + extra::F + function IsInactive{runtime}( + extra::F=(@nospecialize(T) -> (@inline; false)) + ) where {runtime,F} + return new{runtime::Bool,F}(extra) + end +end + +function IsInactive(isinactivetype::IsInactive{runtime}, extra::F) where {runtime,F} + combinedextra(::Type{T}) where {T} = (isinactivetype.extra(T) || extra(T)) + return IsInactive{runtime}(combinedextra) +end + +@inline function (f::IsInactive{runtime,F})( + ::Type{T}, ::Val{nonactive}=Val(false) +) where {runtime,F,T,nonactive} + if runtime + # evaluate f.extra first, as guaranteed_*_nongen may incur runtime dispatch + if nonactive + return f.extra(T) || guaranteed_nonactive_nongen(T, nothing) + else + return f.extra(T) || guaranteed_const_nongen(T, nothing) + end + else + # evaluate guaranteed_* first, as these are always known at compile time + if nonactive + return guaranteed_nonactive(T) || f.extra(T) + else + return guaranteed_const(T) || f.extra(T) + end + end +end ### traits defining active leaf types for recursive_map @inline EnzymeCore.isvectortype(::Type{T}) where {T} = isscalartype(T) @@ -23,7 +102,7 @@ end ::Val{Nout} xs::NTuple{Nin,T}, ::Val{copy_if_inactive}=Val(false), - isinactivetype=guaranteed_const_nongen, + isinactivetype=IsInactive{false}(), )::T newys = recursive_map( [seen::Union{Nothing,IdDict},] @@ -31,7 +110,7 @@ end ys::NTuple{Nout,T}, xs::NTuple{Nin,T}, ::Val{copy_if_inactive}=Val(false), - isinactivetype=guaranteed_const_nongen, + isinactivetype=IsInactive{false}(), )::T !!! warning @@ -126,8 +205,10 @@ that the type notionally represents. deep-copies such that the object reference graph is reproduced also within the inactive parts.) -* `isinactivetype` (optional): Callable mapping types to `Bool` to determines whether the +* `isinactivetype` (optional): Callable mapping types to `Bool` to determine whether the type should be treated according to `copy_if_inactive` (`true`) or recursed into (`false`). + The [`IsInactive`](@ref) type is a helper for obtaining a callable with relevant semantics, + but any callable that maps types to `true` or `false` can be used. """ function recursive_map end @@ -142,7 +223,7 @@ function recursive_map( ys::YS{Nout,T}, xs::NTuple{Nin,T}, copy_if_inactive::Val=Val(false), - isinactivetype::L=guaranteed_const_nongen, + isinactivetype::L=IsInactive{false}(), ) where {F,Nout,Nin,T,L} check_nout(ys) newys = if isinactivetype(T) @@ -162,7 +243,7 @@ function recursive_map( ys::YS{Nout,T}, xs::NTuple{Nin,T}, copy_if_inactive::Val=Val(false), - isinactivetype::L=guaranteed_const_nongen, + isinactivetype::L=IsInactive{false}(), ) where {F,Nout,Nin,T,L} # determine whether to continue recursion, copy/share, or retrieve from cache check_nout(ys) @@ -433,7 +514,8 @@ end f!!, ys::NTuple{Nout,T}, xs::NTuple{Nin,T}, - [::Val{copy_if_inactive},] + ::Val{copy_if_inactive}=Val(false), + isinactivetype::IsInactive=IsInactive{false}(), )::Nothing !!! warning @@ -447,15 +529,23 @@ in-place with the resulting values. This is a simple wrapper that verifies that `T` is a type where all differentiable values can be updated in-place, calls `recursive_map`, and verifies that the returned value is indeed identically the same tuple `ys`. See [`recursive_map`](@ref) for details. + +Note that this wrapper only supports instances of [`IsInactive`](@ref) for the +`isinactivetype` argument, as this is the only way we can insure consistency between the +upfront compatibility check and actual behavior. If this is not appropriate, use +`recursive_map` directly. """ function recursive_map! end function recursive_map!( - f!!::F, ys::NTuple{Nout,T}, xs::NTuple{Nin,T}, copy_if_inactives::Vararg{Val,M} -) where {F,Nout,Nin,T,M} - @assert M <= 1 - check_nonactive(T) - newys = recursive_map(f!!, ys, xs, copy_if_inactives...) + f!!::F, + ys::NTuple{Nout,T}, + xs::NTuple{Nin,T}, + copy_if_inactive::Val=Val(false), + isinactivetype::IsInactive=IsInactive{false}(), +) where {F,Nout,Nin,T} + check_nonactive(T, isinactivetype) + newys = recursive_map(f!!, ys, xs, copy_if_inactive, isinactivetype) @assert newys === ys return nothing end @@ -465,11 +555,11 @@ function recursive_map!( f!!::F, ys::NTuple{Nout,T}, xs::NTuple{Nin,T}, - copy_if_inactives::Vararg{Val,M}, -) where {F,Nout,Nin,T,M} - @assert M <= 1 - check_nonactive(T) - newys = recursive_map(seen, f!!, ys, xs, copy_if_inactives...) + copy_if_inactive::Val=Val(false), + isinactivetype::IsInactive=IsInactive{false}(), +) where {F,Nout,Nin,T} + check_nonactive(T, isinactivetype) + newys = recursive_map(seen, f!!, ys, xs, copy_if_inactive, isinactivetype) @assert newys === ys return nothing end @@ -595,8 +685,8 @@ Base.@propagate_inbounds check_allinitialized(::Tuple{}, i, initialized=true) = return nothing end -@inline function check_nonactive(::Type{T}) where {T} - if !guaranteed_nonactive_nongen(T) +@inline function check_nonactive(::Type{T}, isinactivetype::IsInactive) where {T} + if !isinactivetype(T, Val(true)) throw_nonactive() end return nothing @@ -624,29 +714,51 @@ end end ### EnzymeCore.make_zero(!) implementation -function EnzymeCore.make_zero(prev::T, copy_if_inactives::Vararg{Val,M}) where {T,M} - @assert M <= 1 - new = if iszero(M) && !guaranteed_const_nongen(T) && isvectortype(T) # fallback - # guaranteed_const has precedence over isvectortype for consistency with recursive_map +function EnzymeCore.make_zero(prev::T, args::Vararg{Any,M}) where {T,M} + new = if iszero(M) && !IsInactive{false}()(T) && isvectortype(T) # fallback + # IsInactive has precedence over isvectortype for consistency with recursive handler convert(T, zero(prev)) # convert because zero(prev)::T may fail when eltype(T) is abstract else - only(recursive_map(_make_zero!!, Val(1), (prev,), copy_if_inactives...)) + _make_zero_inner(prev, args...) end return new::T end -function EnzymeCore.make_zero!(val::T, seens::Vararg{IdDict,M}) where {T,M} - @assert M <= 1 +function EnzymeCore.make_zero!(val::T, args::Vararg{Any,M}) where {T,M} @assert !isscalartype(T) # not appropriate for in-place handler - if iszero(M) && !guaranteed_const_nongen(T) && isvectortype(T) # fallback - # isinactivetype has precedence over isvectortype for consistency with recursive_map + if iszero(M) && !IsInactive{false}()(T) && isvectortype(T) # fallback + # IsInactive has precedence over isvectortype for consistency with recursive handler fill!(val, false) else - recursive_map!(seens..., _make_zero!!, (val,), (val,)) + _make_zero_inner!(val, args...) end return nothing end +@inline function _make_zero_inner( + prev::T, copy_if_inactive::Val=Val(false), ::Val{runtime_inactive}=Val(false) +) where {T,runtime_inactive} + isinactivetype = IsInactive{runtime_inactive}() + news = recursive_map(_make_zero!!, Val(1), (prev,), copy_if_inactive, isinactivetype) + return only(news)::T +end + +@inline function _make_zero_inner!( + val::T, ::Val{runtime_inactive}=Val(false) +) where {T,runtime_inactive} + isinactivetype = IsInactive{runtime_inactive}() + recursive_map!(_make_zero!!, (val,), (val,), Val(false), isinactivetype) + return nothing +end + +@inline function _make_zero_inner!( + val::T, seen::IdDict, ::Val{runtime_inactive}=Val(false) +) where {T,runtime_inactive} + isinactivetype = IsInactive{runtime_inactive}() + recursive_map!(seen, _make_zero!!, (val,), (val,), Val(false), isinactivetype) + return nothing +end + function _make_zero!!(prev::T) where {T} @assert isvectortype(T) # otherwise infinite loop return (EnzymeCore.make_zero(prev)::T,) @@ -662,10 +774,15 @@ end # alternative entry point for passing custom IdDict function EnzymeCore.make_zero( - ::Type{T}, seen::IdDict, prev::T, copy_if_inactives::Vararg{Val,M} -) where {T,M} - @assert M <= 1 - return only(recursive_map(seen, _make_zero!!, Val(1), (prev,), copy_if_inactives...))::T + ::Type{T}, + seen::IdDict, + prev::T, + copy_if_inactive::Val=Val(false), + ::Val{runtime_inactive}=Val(false), +) where {T,runtime_inactive} + isinactivetype = IsInactive{runtime_inactive}() + news = recursive_map(seen, _make_zero!!, Val(1), (prev,), copy_if_inactive, isinactivetype) + return only(news)::T end end # module RecursiveMaps diff --git a/test/recursive_maps.jl b/test/recursive_maps.jl index fc903e5139..256660b300 100644 --- a/test/recursive_maps.jl +++ b/test/recursive_maps.jl @@ -551,34 +551,36 @@ function test_make_zero() @test a[1] === 1.0 # no mutation of original end @testset "runtime inactive" begin - # verify that MutableWrapper is seen as active + # verify that MutableWrapper is seen as active by both variants a = MutableWrapper(1.0) @assert !EnzymeRules.inactive_type(typeof(a)) - a_makez = make_zero(a) - @test a_makez == MutableWrapper(0.0) + a_makez = make_zero(a, Val(false), Val(false)) + @assert a_makez == MutableWrapper(0.0) + a_makez = make_zero(a, Val(false), Val(true)) + @assert a_makez == MutableWrapper(0.0) # mark MutableWrapper as inactive @eval @inline EnzymeRules.inactive_type(::Type{<:MutableWrapper}) = true - # verify that MutableWrapper is seen as inactive and shared/copied according to - # copy_if_inactive - @assert a.x === 1.0 # sanity check - a_makez = @invokelatest make_zero(a) - @test a_makez == a # equal - @assert a.x === 1.0 # sanity check - a_makez = @invokelatest make_zero(a, Val(false)) - @test a_makez === a # identical - @assert a.x === 1.0 # sanity check - a_makez = @invokelatest make_zero(a, Val(true)) - @test a_makez !== a # not identical - @test a_makez == a # but equal + # runtime_inactive == false => redefined inactive_type should have no effect + a_makez = @invokelatest make_zero(a, Val(false), Val(false)) + @test a_makez == MutableWrapper(0.0) + + # runtime_inactive == true => redefined inactive_type should take effect + # MutableWrapper considered inactive and treated according to copy_if_inactive + a_makez = @invokelatest make_zero(a, Val(false), Val(true)) + @test a_makez === a + a_makez = @invokelatest make_zero(a, Val(true), Val(true)) + @test a_makez !== a + @test a_makez == MutableWrapper(1.0) # mark MutableWrapper as active again @eval @inline EnzymeRules.inactive_type(::Type{<:MutableWrapper}) = false - # verify that MutableWrapper is seen as active - @assert a.x === 1.0 # sanity check - a_makez = @invokelatest make_zero(a) + # verify that MutableWrapper is seen as active by both variants + a_makez = @invokelatest make_zero(a, Val(false), Val(false)) + @test a_makez == MutableWrapper(0.0) + a_makez = @invokelatest make_zero(a, Val(false), Val(true)) @test a_makez == MutableWrapper(0.0) end @testset "undefined fields/unassigned elements" begin @@ -842,26 +844,38 @@ function test_make_zero!() @test a[1] === 0.0 # correct value end @testset "runtime inactive" begin - # verify that MutableWrapper is seen as active + # verify that MutableWrapper is seen as active by both variants a = MutableWrapper(1.0) @assert !EnzymeRules.inactive_type(typeof(a)) - make_zero!(a) - @test a == MutableWrapper(0.0) + make_zero!(a, Val(false)) + @assert a == MutableWrapper(0.0) + a.x = 1.0 + make_zero!(a, Val(true)) + @assert a == MutableWrapper(0.0) # mark MutableWrapper as inactive @eval @inline EnzymeRules.inactive_type(::Type{<:MutableWrapper}) = true - # verify that MutableWrapper is seen as inactive + # runtime_inactive == false => redefined inactive_type should have no effect a.x = 1.0 - @invokelatest make_zero!(a) + @invokelatest make_zero!(a, Val(false)) + @test a == MutableWrapper(0.0) + + # runtime_inactive == true => redefined inactive_type should take effect + # MutableWrapper considered inactive and won't be zeroed + a.x = 1.0 + @invokelatest make_zero!(a, Val(true)) @test a == MutableWrapper(1.0) # mark MutableWrapper as active again @eval @inline EnzymeRules.inactive_type(::Type{<:MutableWrapper}) = false - # verify that MutableWrapper is seen as active + # verify that MutableWrapper is seen as active by both variants + a.x = 1.0 + @invokelatest make_zero!(a, Val(false)) + @test a == MutableWrapper(0.0) a.x = 1.0 - @invokelatest make_zero!(a) + @invokelatest make_zero!(a, Val(true)) @test a == MutableWrapper(0.0) end @testset "undefined fields/unassigned elements" begin From 25593d646091fdaaf5381cc3e593a2788f3d6669 Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Sun, 19 Jan 2025 23:06:38 -0800 Subject: [PATCH 07/17] Combine copy_if_inactive and runtime_inactive into a config type Greatly simplifies the interface and setting/sticking to defaults --- lib/EnzymeCore/src/EnzymeCore.jl | 64 +++-- src/typeutils/recursive_add.jl | 72 +++-- src/typeutils/recursive_maps.jl | 458 +++++++++++++------------------ test/recursive_maps.jl | 58 +++- 4 files changed, 319 insertions(+), 333 deletions(-) diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index 99694c02bd..a84f990747 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -506,29 +506,35 @@ function autodiff_thunk end function autodiff_deferred_thunk end """ + make_zero(prev::T; copy_if_inactive=Val(false), runtime_inactive=Val(false))::T + make_zero(prev::T, ::Val{copy_if_inactive}[, ::Val{runtime_inactive}])::T make_zero( - prev::T, ::Val{copy_if_inactive}=Val(false), ::Val{runtime_inactive}=Val(false) + ::Type{T}, seen::IdDict, prev::T; + copy_if_inactive=Val(false), runtime_inactive=Val(false), )::T make_zero( - ::Type{T}, - seen::IdDict, - prev::T, - ::Val{copy_if_inactive}=Val(false), - ::Val{runtime_inactive}=Val(false), + ::Type{T}, seen::IdDict, prev::T, ::Val{copy_if_inactive}[, ::Val{runtime_inactive}] )::T -Recursively make a copy of the value `prev::T` in which all differentiable values are -zeroed. The argument `copy_if_inactive` specifies what to do if the type `T` or any -of its constituent parts is guaranteed to be inactive (non-differentiable): reuse `prev`s -instance (the default) or make a copy. +Recursively make a copy of the value `prev::T` in which all differentiable values are zeroed. -The argument `runtime_inactive` specifies whether each constituent type is checked for being -guaranteed inactive at runtime for every call to `make_zero`, or if this can be checked once -at compile-time and reused across multiple calls to `make_zero` and related functions (the -default). Runtime checks are necessary to pick up recently added methods to -`EnzymeRules.inactive_type`, but may incur a significant performance penalty and is usually -not needed unless `EnzymeRules.inactive_type` is extended interactively for types that have -previously been passed to `make_zero` or related functions. +The argument `copy_if_inactive` specifies what to do if the type `T` or any +of its constituent parts is guaranteed to be inactive (non-differentiable): reuse `prev`s +instance (if `Val(false)`, the default) or make a copy (if `Val(true)`). + +The argument `runtime_inactive` specifies whether this function should respect runtime +semantics when determining if a type is guaranteed inactive. If `Val(true)`, changes and +additions to the methods of `EnzymeRules.inactive_type` will be reflected in the behavior of +this function. If `Val(false)`, the inactivity of a type is determined once per Julia +session and reused in every subsequent call with `runtime_inactive=Val(false)`, regardless +of changes to `EnzymeRules.inactive_type` (in technical terms, there will be no invalidation +when changing `EnzymeRules.inactive_type`). Using `runtime_inactive = Val(false)` may be +desireable in interactive sessions, but can sometimes impose a performance penalty and may +in rare cases break gradient compilation when used inside custom rules. Hence +`runtime_inactive = Val(true)` is preferred in non-interactive usage and is the default. + +`copy_if_inactive` and `runtime_inactive` may be given as either positional or keywords +arguments, but not a combination. Extending this method for custom types is rarely needed. If you implement a new type, such as a GPU array type, on which `make_zero` should directly invoke `zero` when the eltype is @@ -539,20 +545,28 @@ scalar, it is sufficient to implement `Base.zero` and make sure your type subtyp function make_zero end """ - make_zero!(val::T, [seen::IdDict], ::Val{runtime_inactive}=Val(false))::Nothing + make_zero!(val::T, [seen::IdDict]; runtime_inactive=Val(true))::Nothing + make_zero!(val::T, [seen::IdDict], ::Val{runtime_inactive})::Nothing Recursively set a variable's differentiable values to zero. Only applicable for types `T` that are mutable or hold all differentiable values in mutable storage (e.g., `Tuple{Vector{Float64}}` qualifies but `Tuple{Float64}` does not). The recursion skips over parts of `val` that are guaranteed to be inactive. -The argument `runtime_inactive` specifies whether each constituent type is checked for being -guaranteed inactive at runtime for every call to `make_zero!`, or if this can be checked once -at compile-time and reused across multiple calls to `make_zero!` and related functions (the -default). Runtime checks are necessary to pick up recently added methods to -`EnzymeRules.inactive_type`, but may incur a significant performance penalty and is usually -not needed unless `EnzymeRules.inactive_type` is extended interactively for types that have -previously been passed to `make_zero!` or related functions. +The argument `runtime_inactive` specifies whether this function should respect runtime +semantics when determining if a type is guaranteed inactive. If `Val(true)`, changes and +additions to the methods of `EnzymeRules.inactive_type` will be reflected in the behavior of +this function. If `Val(false)`, the inactivity of a type is determined once per Julia +session and reused in every subsequent call with `runtime_inactive=Val(false)`, regardless +of changes to `EnzymeRules.inactive_type` (in technical terms, there will be no invalidation +when changing `EnzymeRules.inactive_type`). + +Using `runtime_inactive = Val(false)` may be preferred in interactive sessions, but can +sometimes impose a performance penalty and may in rare cases break gradient compilation when +used inside custom rules. Hence `runtime_inactive = Val(true)` is recommended for +non-interactive usage and is the default. + +`runtime_inactive` may be given as either a positional or a keyword argument. Extending this method for custom types is rarely needed. If you implement a new mutable type, such as a GPU array type, on which `make_zero!` should directly invoke diff --git a/src/typeutils/recursive_add.jl b/src/typeutils/recursive_add.jl index cb421728eb..482b6fc503 100644 --- a/src/typeutils/recursive_add.jl +++ b/src/typeutils/recursive_add.jl @@ -13,64 +13,62 @@ generalization of `x .+ f.(y)`. The function `f` must return values of the same type as its argument. -The optional argument `forcelhs` takes a callable such that if `forcelhs(S) == true`, values -`zi::S` will be set to `zi = xi`. The default returns true for non-differentiable types, -such that `zi = xi + f(yi)` applies to differentiable values, while `zi = xi` applies to -non-differentiable values. +The optional argument `forcelhs` takes a function such that if `forcelhs(S) == true`, values +`zi::S` will be set to `zi = xi`. The default returns true for non-differentiable (inactive) +types, such that `zi = xi + f(yi)` applies to differentiable values, while `zi = xi` applies +to non-differentiable values. If a custom callable is passed, it is combined with the +default, as `recursive_add` is not generally capable of traversing inactive objects. """ function recursive_add(x::T, y::T, f::F=identity, forcelhs::L=guaranteed_const) where {T,F,L} function addf(xi::S, yi::S) where {S} @assert EnzymeCore.isvectortype(S) return ((xi + f(yi))::S,) end - return only(recursive_map(addf, Val(1), (x, y), Val(false), forcelhs))::T + config = RecursiveMaps.InactiveConfig(forcelhs) + return only(recursive_map(addf, Val(1), (x, y), config))::T end """ - accumulate_seen!(f, seen::IdDict, ::Val{runtime_inactive}=Val(false)) - accumulate_seen!(f, seen::IdDict, isinactivetype::RecursiveMaps.IsInactive) + accumulate_seen!(f, seen::IdDict; frozen_inactive=Val(false)) + accumulate_seen!(f, seen::IdDict, ::Val{frozen_inactive}) + accumulate_seen!( + f, seen::IdDict, config::RecursiveMaps.InactiveConfig=RecursiveMaps.InactiveConfig() + ) !!! warning Internal function, documented for developer convenience but not covered by semver API stability guarantees -Recursively accumulate from values into keys, generalizing key .+= f.(value), for each -key-value pair in `seen::IdDict` where each key must be a mutable object or non-isbits -vector type instance mappping to another object of the same type and structure. Typically -`seen` is populated by `make_zero` (or some other single-argument invocation of -`recursive_map`), mapping components of its argument to the corresponding component of the -returned value. +Recursively accumulate from values into keys, generalizing `key .+= f.(value)` to arbitrary +types. This accumulation is applied to each key-value pair in `seen::IdDict` where each key +is of a mutable or non-isbits vector type and the corresponding value is of the same type +and structure. Typically `seen` is populated by `make_zero`/`recursive_map`, mapping parts +of its input to the corresponding parts of the returned value. -The recursion stops at instances of types that are themselves cached by `make_zero` -(`recursive_map`), as these objects should have their own entries in `seen`. The recursion -also stops at inactive objects that would not be zeroed by `make_zero`. +The recursion stops at objects of types that are themselves cached by +`make_zero`/`recursive_map`, as these objects should have their own entries in `seen`. The +recursion also stops at inactive objects that would be skipped by `make_zero`/`recursive_map`. -If the optional `::Val{runtime_inactive}` argument was passed to `make_zero`, the same value -should be passed to `accumulate_seen` for consistency. If needed, a custom -`RecursiveMaps.IsInactive` instance can be provided instead. +If the optional argument `::Val{runtime_inactive}` was passed to `make_zero`, or +`config::RecursiveMaps.InactiveConfig` was passed to `recursive_map`, the same value should +be passed to `accumulate_seen` to enzure consistency. """ -function accumulate_seen!( - f::F, seen::IdDict, ::Val{runtime_inactive}=Val(false) -) where {F,runtime_inactive} - accumulate_seen!(f, seen, RecursiveMaps.IsInactive{runtime_inactive}()) +function accumulate_seen! end + +function accumulate_seen!(f::F, seen::IdDict, args::Vararg{Any,M}; kws...) where {F,M} + accumulate_seen!(f, seen, RecursiveMaps.make_zero!_config(args...; kws...)) return nothing end -function accumulate_seen!( - f::F, seen::IdDict, isinactivetype::RecursiveMaps.IsInactive -) where {F} - isinactive_or_cachedtype = RecursiveMaps.IsInactive( - isinactivetype, RecursiveMaps.iscachedtype - ) +function accumulate_seen!(f::F, seen::IdDict, config::RecursiveMaps.InactiveConfig) where {F} + cachedconfig = RecursiveMaps.InactiveConfig(config, RecursiveMaps.iscachedtype) for (k, v) in seen - _accumulate_seen_item!(f, k, v, isinactivetype, isinactive_or_cachedtype) + _accumulate_seen_item!(f, k, v, config, cachedconfig) end return nothing end -function _accumulate_seen_item!( - f::F, k::T, v::T, isinactivetype, isinactive_or_cachedtype -) where {F,T} +function _accumulate_seen_item!(f::F, k::T, v::T, config, cachedconfig) where {F,T} function addf!!(ki::S, vi::S) where {S} @assert EnzymeCore.isvectortype(S) return ((ki .+ f.(vi))::S,) @@ -82,11 +80,9 @@ function _accumulate_seen_item!( ki .+= f.(vi) return (ki::S,) end - RecursiveMaps.check_nonactive(T, isinactivetype) - if !isinactivetype(T) - newks = RecursiveMaps.recursive_map_inner( - nothing, addf!!, (k,), (k, v), Val(false), isinactive_or_cachedtype - ) + RecursiveMaps.check_nonactive(T, config) + if !RecursiveMaps.isinactivetype(T, config) + newks = RecursiveMaps.recursive_map_inner(nothing, addf!!, (k,), (k, v), cachedconfig) @assert only(newks) === k end return nothing diff --git a/src/typeutils/recursive_maps.jl b/src/typeutils/recursive_maps.jl index 49acc04735..15f5d9dd3e 100644 --- a/src/typeutils/recursive_maps.jl +++ b/src/typeutils/recursive_maps.jl @@ -4,82 +4,81 @@ using EnzymeCore: EnzymeCore, isvectortype, isscalartype using ..Compiler: guaranteed_const, guaranteed_const_nongen, guaranteed_nonactive, guaranteed_nonactive_nongen -### IsInactive: helper for creating consistent inactive/nonactive type checkers +### Config type for setting inactive/nonactive options """ - isinactivetype = IsInactive{runtime::Bool}(extra=(T -> false)) - isinactivetype = IsInactive(isinactivetype::IsInactive, extra) + config = InactiveConfig( + extra=(T -> false); copy_if_inactive=Val(false), runtime_inactive=Val(false) + ) + config = InactiveConfig{copy_if_inactive::Bool,runtime_inactive::Bool}(extra) + newconfig = InactiveConfig(config::InactiveConfig, extra) !!! warning Internal type, documented for developer convenience but not covered by semver API stability guarantees -Create a callable `isinactivetype` such that `isinactivetype(T) == true` if the type `T` is -non-differentiable, that is, if differentiable values can never be reached from any instance -of the type (that is, the activity state of `T` is `AnyState`). +Config type for specifying which parts of objects should be skipped by `recursive_map{!}`. -The callable takes an optional argument `Val(nonactive::Bool)`, such that the full signature -is - -```julia -isinactivetype(::Type{T}, ::Val{nonactive}=Val(false))::Bool -``` - -Setting `nonactive == true` selects for _nonactive_ types, which is a superset of inactive -types that also includes types `T` where every differentiable value can be mutated without -creating a new instance of `T` (that is, the activity state of `T` is either `AnyState` or -`DupState`). +At a minimum, parts that Enzyme always considers inactive are skipped. An inactive type is a +type for which Enzyme can prove that a differentiable value can never be reached from any +instance of the type. The optional argument `extra` takes a function defining additional types that should be -treated as inactive regardless of their nominal activity state; that is, - -```julia -IsInactive{runtime}(extra)(T, args...) == IsInactive{runtime}()(T, args...) || extra(T) -``` - -The constructor `IsInactive(isinactivetype::IsInactive{runtime}, extra)` can be used to -extend an existing instance `isinactivetype::IsInactive` with an additional `extra` -function, and is more or less equivalent to -`IsInactive{runtime}(T -> isinactivetype.extra(T) || extra(T))`. - -The type parameter `runtime` specifies whether the activity state of a type is queried at -runtime every time the callable is invoked (`true`), or if compile-time queries from earlier -calls can be reused (`false`). Runtime querying is necessary to pick up recently added -methods to `EnzymeRules.inactive_type`, but may incur a significant performance penalty and -is usually not needed unless `EnzymeRules.inactive_type` is extended interactively for types -that have previously been passed to an instance of `IsInactive{false}`. +skipped regardless of their nominal activity. `extra` should be a plain function +or callable of a singleton type, not a closure or otherwise stateful callable; this is to +ensure that an `InactiveConfig` instance is fully specified by its type. + +The parameter `copy_if_inactive` specifies whether `recursive_map{!}` should share (if +`Val(false)`, the default) or deep-copy (if `Val(true)`) inactive/skipped parts from inputs +to outputs. + +The parameter `runtime_inactive` specifies whether `recursive_map{!}` should respect runtime +semantics when determining if a type is guaranteed inactive. If `Val(false)`, guaranteed +inactivity is determined once during compilation of the internal generated function +`active_reg_nothrow`, and won't be invalidated by subsequent changes to the +`EnzymeRules.inactive_type` method table. If `Val(true)`, the generated function is not used +and changes to `EnzymeRules.inactive_type` are picked up through invalidation as usual. + +Using `runtime_inactive = Val(false)` may be preferred in interactive sessions, but +performance may sometimes suffer if the activity states of all types cannot be resolved at +compile time, and in some cases this mode has been observed to break gradient compilation +when `recursive_map{!}` is used inside custom rules. Hence `runtime_inactive = Val(true)` is +recommended for non-interactive usage and is the default. + +The updating constructor `InactiveConfig(config::InactiveConfig, extra)` returns a new +config that extends `config` with an additional `extra` function. """ -struct IsInactive{runtime,F} - extra::F - function IsInactive{runtime}( - extra::F=(@nospecialize(T) -> (@inline; false)) - ) where {runtime,F} - return new{runtime::Bool,F}(extra) +struct InactiveConfig{copy_if_inactive,runtime_inactive,E} + extra::E + function InactiveConfig{C,R}(extra::E) where {C,R,E} + @assert Base.issingletontype(E) + return new{C::Bool,R::Bool,E}(extra) end end -function IsInactive(isinactivetype::IsInactive{runtime}, extra::F) where {runtime,F} - combinedextra(::Type{T}) where {T} = (isinactivetype.extra(T) || extra(T)) - return IsInactive{runtime}(combinedextra) +function InactiveConfig( + extra::E=(_ -> (@nospecialize; false)); + copy_if_inactive::Val{C}=Val(false), runtime_inactive::Val{R}=Val(false), +) where {E,C,R} + return InactiveConfig{C,R}(extra) end -@inline function (f::IsInactive{runtime,F})( - ::Type{T}, ::Val{nonactive}=Val(false) -) where {runtime,F,T,nonactive} - if runtime - # evaluate f.extra first, as guaranteed_*_nongen may incur runtime dispatch - if nonactive - return f.extra(T) || guaranteed_nonactive_nongen(T, nothing) - else - return f.extra(T) || guaranteed_const_nongen(T, nothing) - end - else - # evaluate guaranteed_* first, as these are always known at compile time - if nonactive - return guaranteed_nonactive(T) || f.extra(T) - else - return guaranteed_const(T) || f.extra(T) - end - end +function InactiveConfig(config::InactiveConfig{C,R}, extra::E) where {C,R,E} + @inline combinedextra(::Type{T}) where {T} = (config.extra(T) || extra(T)) + return InactiveConfig{C,R}(combinedextra) +end + +function isinactivetype(::Type{T}, config::InactiveConfig{C,false}) where {T,C} + return guaranteed_const(T) || config.extra(T) # call guaranteed_const first, as this is a constant at runtime +end +function isinactivetype(::Type{T}, config::InactiveConfig{C,true}) where {T,C} + return config.extra(T) || guaranteed_const_nongen(T, nothing) # call config.extra first, as guaranteed_const_nongen may incur runtime dispatch +end + +function isnonactivetype(::Type{T}, config::InactiveConfig{C,false}) where {T,C} + return guaranteed_nonactive(T) || config.extra(T) # call guaranteed_const first, as this is a constant at runtime +end +function isnonactivetype(::Type{T}, config::InactiveConfig{C,true}) where {T,C} + return config.extra(T) || guaranteed_nonactive_nongen(T, nothing) # call config.extra first, as guaranteed_nonactive_nongen may incur runtime dispatch end ### traits defining active leaf types for recursive_map @@ -101,44 +100,41 @@ end f, ::Val{Nout} xs::NTuple{Nin,T}, - ::Val{copy_if_inactive}=Val(false), - isinactivetype=IsInactive{false}(), + config::InactiveConfig=InactiveConfig(), )::T newys = recursive_map( [seen::Union{Nothing,IdDict},] f, ys::NTuple{Nout,T}, xs::NTuple{Nin,T}, - ::Val{copy_if_inactive}=Val(false), - isinactivetype=IsInactive{false}(), + config::InactiveConfig=InactiveConfig(), )::T !!! warning Internal function, documented for developer convenience but not covered by semver API stability guarantees -Recurse through `Nin` objects `xs = (x1::T, x2::T, ..., xNin::T)` of the same type, mapping the -function `f` over every differentiable value encountered and building `Nout` new objects +Recurse through `Nin` objects `xs = (x1::T, x2::T, ..., xNin::T)` of the same type, mapping +the function `f` over every differentiable value encountered and building `Nout` new objects `(y1::T, ...)` from the resulting values `(y1_i, ...) = f(x1_i, ..., xNin_i)`. Only `Nout == 1` and `Nout == 2` are supported. The trait `EnzymeCore.isvectortype`(@ref) determines which values are considered -differentiable leaf nodes at which recursion terminates and `f` is invoked. See the -docstring for [`EnzymeCore.isvectortype`](@ref) and the related -[`EnzymeCore.isscalartype`](@ref) for more information. +leaf nodes at which to terminate recursion invoke `f`. See the docstring for +[`EnzymeCore.isvectortype`](@ref) and the related [`EnzymeCore.isscalartype`](@ref) for more +information. A tuple of existing objects `ys = (y1::T, ...)` can be passed, in which case the `ys` are updated "partially-in-place": any parts of the `ys` that are mutable or non-differentiable are reused in the returned object tuple `newys`, while immutable differentiable parts are -handled out-of-place as if the `ys` were not passed (this can be seen as a recursive -generalization of the BangBang.jl idiom). If `T` itself is a mutable type, the `ys` are -modified in-place and returned, such that `newys === ys`. +handled out-of-place as if the `ys` were not passed. If `T` itself is a mutable type, the +`ys` are modified in-place and returned, such that `newys === ys`. -The recursion and mapping operates on the structure of `T` as defined by struct fields and -plain array elements, not on the values provided through an iteration or array interface. -For example, given a structured matrix wrapper or sparse array type, this function recurses -into the struct type and the plain arrays held within, rather than operating on the array -that the type notionally represents. +The recursion and mapping operate on the structure of `T` as defined by struct fields and +plain array elements, not on the values provided through iteration or array interfaces. For +example, given a structured matrix wrapper or sparse array type, this function recurses into +the struct type and operates on the plain arrays held within, rather than operating on the +array that the type notionally represents. # Arguments @@ -147,10 +143,9 @@ that the type notionally represents. that of the `xs`, including cycles (i.e., recursive substructures) and multiple paths to the same objects. If not provided, an `IdDict` will be allocated internally if required. - If `nothing` is provided, object identity is not tracked. In this case, objects with - multiple references are duplicated such that the `ys`s object reference graph becomes a - tree, cycles lead to infinite recursion and stack overflow, and `copy_if_inactive == true` - will likely cause errors. This is useful only in specific cases. + If `nothing` is provided, object identity is tracking is turned off. In this case, objects + with multiple references are duplicated such that the `ys`s object reference graph becomes + a tree, but cycles will result in infinite recursion and stack overflow. * `f`: Function mapping leaf nodes within the `xs` to the corresponding leaf nodes in the `ys`, that is, `(y1_i, ...) = f(x1_i::U, ..., xNin_i::U)::NTuple{Nout,U}`. The function @@ -161,54 +156,44 @@ that the type notionally represents. non-scalar type `U`, `f` should also have a partially-in-place method `(newy1_i, ...) === f(y1_i::U, ..., yNout_i::U, x1_i::U, ..., xNin_i::U)::NTuple{Nout,U}` that modifies and reuses any mutable parts of the `yj_i`; in particular, if `U` is a - mutable type, this method should return `newyj_i === yj_i`. If a non-isbits type `U` - should always be handled using the out-of-place signature, extend - [`EnzymeCore.isscalartype`](@ref) such that `isscalartype(U) == true`. + mutable type, this method should return `newyj_i === yj_i`. + + If a non-isbits leaf type `U` must always be handled using the out-of-place signature, + define the method `EnzymeCore.isscalartype(::Type{U}) = true`. See [`EnzymeCore.isvectortype`](@ref) and [`EnzymeCore.isscalartype`](@ref) for more details about leaf types and scalar types. * `::Val{Nout}` or `ys::NTuple{Nout,T}`: For out-of-place operation, pass `Val(Nout)` where - `Nout in (1, 2)` is the length of the tuple returned by `f`, that is, the length of the - expected return value `ys` (this is required; `Nout` never inferred). For - partially-in-place operation, pass the existing tuple `ys::NTuple{Nout,T}` containing the - values to be modified. + `Nout in (1, 2)` matches the length of the tuple returned by `f`. For partially-in-place + operation, pass the existing tuple `ys::NTuple{Nout,T}` containing the values to be + modified. -* `xs::NTuple{N,T}`: Tuple of `N` objects of the same type `T` over which `f` is mapped. +* `xs::NTuple{N,T}`: Tuple of `N` objects of the same type `T`. The first object `x1 = first(xs)` is the reference for graph structure and non-differentiable values when constructing the returned object. In particular: - * When `ys` is not passed, the returned objects take any non-differentiable parts from - `x1`. (When `ys` is passed, its non-differentiable parts are kept unchanged in the - returned object, unless they are not initialized, in which case they are taken from - `x1`.) + * When `ys` is not passed, the returned `ys` take any non-differentiable parts from `x1`. + * When `ys` is passed, its non-differentiable parts are kept unchanged, unless they are + uninitialized, in which case they are taken from `x1`. * The graph of object references in `x1` is the one which is reproduced in the returned object. For each instance of multiple paths and cycles within `x1`, the same structure must be present in the other objects `x2, ..., xN`, otherwise the corresponding values - in the `ys` would not be uniquely defined. However, `x2, ..., xN` may contain multiple - paths or cycles that are not present in `x1`; these do not affect the structure of `ys`. + in the `ys` would not be uniquely defined. However, `x2, ..., xN` may contain additional + converging paths or cycles that are not present in `x1`; these do not affect the `ys`. * If any values within `x1` are not initialized (that is, struct fields are undefined or array elements are unassigned), they are left uninitialized in the returned object. If any such values are mutable and `ys` is passed, the corresponding value in `y` must not - already be initialized, since initialized values cannot be nulled. Conversely, for every + already be initialized (initialized values cannot be nulled). Conversely, for every value in `x1` that is initialized, the corresponding values in `x2, ..., xN` must also be initialized, such that the corresponding values of the `ys` can be computed (however, - values in `x2, ..., xN` can be initialized while the corresponding value in `x1` is not; - such values are ignored.) - -* `::Val{copy_if_inactive::Bool}` (optional): When a non-differentiable part of `x1` is - included in the returned object, either because an object tuple `ys` is not passed or this - part of the `ys` is not initialized, `copy_if_inactive` determines how: if - `copy_if_inactive == false`, it is shared as `yj_i = x1_i`; if `copy_if_inactive == true`, - it is deep-copied, more-or-less as `yj_i = deepcopy(x1_i)` (the difference is that when - `x1` has several non-differentiable parts, object identity is tracked across the multiple - deep-copies such that the object reference graph is reproduced also within the inactive - parts.) - -* `isinactivetype` (optional): Callable mapping types to `Bool` to determine whether the - type should be treated according to `copy_if_inactive` (`true`) or recursed into (`false`). - The [`IsInactive`](@ref) type is a helper for obtaining a callable with relevant semantics, - but any callable that maps types to `true` or `false` can be used. + `x2, ..., xN` may have initialized values where `x1` has uninitialized values). + +* `config::InactiveConfig` (optional): Config object detailing how to deal with + non-differentiable (inactive) parts. The config specifies whether non-differentiable parts + should be shared or deep-copied from `x1` to the `ys`, and whether any additional types + should be skipped in addition to those Enzyme always considers inactive. See + [`InactiveConfig`](@ref) for details. """ function recursive_map end @@ -219,19 +204,15 @@ const YS{Nout,T} = Union{Val{Nout},NTuple{Nout,T}} ## main entry point: set default arguments, allocate IdDict if needed, exit early if possible function recursive_map( - f::F, - ys::YS{Nout,T}, - xs::NTuple{Nin,T}, - copy_if_inactive::Val=Val(false), - isinactivetype::L=IsInactive{false}(), -) where {F,Nout,Nin,T,L} + f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, config::InactiveConfig=InactiveConfig() +) where {F,Nout,Nin,T} check_nout(ys) - newys = if isinactivetype(T) - recursive_map_inactive(nothing, ys, xs, copy_if_inactive) + newys = if isinactivetype(T, config) + recursive_map_inactive(nothing, ys, xs, config) elseif isvectortype(T) || isbitstype(T) - recursive_map_inner(nothing, f, ys, xs, copy_if_inactive, isinactivetype) + recursive_map_inner(nothing, f, ys, xs, config) else - recursive_map_inner(IdDict(), f, ys, xs, copy_if_inactive, isinactivetype) + recursive_map_inner(IdDict(), f, ys, xs, config) end return newys::NTuple{Nout,T} end @@ -242,47 +223,46 @@ function recursive_map( f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, - copy_if_inactive::Val=Val(false), - isinactivetype::L=IsInactive{false}(), -) where {F,Nout,Nin,T,L} + config::InactiveConfig=InactiveConfig(), +) where {F,Nout,Nin,T} # determine whether to continue recursion, copy/share, or retrieve from cache check_nout(ys) - newys = if isinactivetype(T) - recursive_map_inactive(seen, ys, xs, copy_if_inactive) + newys = if isinactivetype(T, config) + recursive_map_inactive(seen, ys, xs, config) elseif isbitstype(T) # no object identity to to track in this branch - recursive_map_inner(nothing, f, ys, xs, copy_if_inactive, isinactivetype) + recursive_map_inner(nothing, f, ys, xs, config) elseif hascache(seen, xs) getcached(seen, Val(Nout), xs) else - recursive_map_inner(seen, f, ys, xs, copy_if_inactive, isinactivetype) + recursive_map_inner(seen, f, ys, xs, config) end return newys::NTuple{Nout,T} end @inline function recursive_map_inner( - seen, f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, copy_if_inactive, isinactivetype::L -) where {F,Nout,Nin,T,L} + seen, f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, config +) where {F,Nout,Nin,T} # forward to appropriate handler for leaf vs. mutable vs. immutable type @assert !isabstracttype(T) @assert isconcretetype(T) newys = if isvectortype(T) recursive_map_leaf(seen, f, ys, xs) elseif ismutabletype(T) - recursive_map_mutable(seen, f, ys, xs, copy_if_inactive, isinactivetype) + recursive_map_mutable(seen, f, ys, xs, config) else - recursive_map_immutable(seen, f, ys, xs, copy_if_inactive, isinactivetype) + recursive_map_immutable(seen, f, ys, xs, config) end return newys::NTuple{Nout,T} end @inline function recursive_map_mutable( - seen, f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, copy_if_inactive, isinactivetype::L -) where {F,Nout,Nin,T,L} + seen, f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, config +) where {F,Nout,Nin,T} @assert ismutabletype(T) if !hasvalues(ys) && !(T <: DenseArray) && all(isbitstype, fieldtypes(T)) # fast path for out-of-place handling when all fields are bitstypes, which rules # out undefined fields and circular references - newys = recursive_map_new(seen, f, ys, xs, copy_if_inactive, isinactivetype) + newys = recursive_map_new(seen, f, ys, xs, config) maybecache!(seen, newys, xs) else newys = if hasvalues(ys) @@ -292,71 +272,59 @@ end ntuple(_ -> (@inline; _similar(x1)), Val(Nout)) end maybecache!(seen, newys, xs) - recursive_map_mutable_inner!(seen, f, newys, ys, xs, copy_if_inactive, isinactivetype) + recursive_map_mutable_inner!(seen, f, newys, ys, xs, config) end return newys::NTuple{Nout,T} end @inline function recursive_map_mutable_inner!( - seen, - f::F, - newys::NTuple{Nout,T}, - ys::YS{Nout,T}, - xs::NTuple{Nin,T}, - copy_if_inactive, - isinactivetype::L, -) where {F,Nout,Nin,T<:DenseArray,L} + seen, f::F, newys::NTuple{Nout,T}, ys::YS{Nout,T}, xs::NTuple{Nin,T}, config +) where {F,Nout,Nin,T<:DenseArray} if (Nout == 1) && isbitstype(eltype(T)) newy = only(newys) if hasvalues(ys) y = only(ys) broadcast!(newy, y, xs...) do y_i, xs_i... - only(recursive_map(nothing, f, (y_i,), xs_i, copy_if_inactive, isinactivetype)) + only(recursive_map(nothing, f, (y_i,), xs_i, config)) end else broadcast!(newy, xs...) do xs_i... - only(recursive_map(nothing, f, Val(1), xs_i, copy_if_inactive, isinactivetype)) + only(recursive_map(nothing, f, Val(1), xs_i, config)) end end else @inbounds for i in eachindex(newys..., xs...) - recursive_map_item!(i, seen, f, newys, ys, xs, copy_if_inactive, isinactivetype) + recursive_map_item!(i, seen, f, newys, ys, xs, config) end end return nothing end @generated function recursive_map_mutable_inner!( - seen, - f::F, - newys::NTuple{Nout,T}, - ys::YS{Nout,T}, - xs::NTuple{Nin,T}, - copy_if_inactive, - isinactivetype::L, -) where {F,Nout,Nin,T,L} + seen, f::F, newys::NTuple{Nout,T}, ys::YS{Nout,T}, xs::NTuple{Nin,T}, config +) where {F,Nout,Nin,T} return quote @inline Base.Cartesian.@nexprs $(fieldcount(T)) i -> @inbounds begin - recursive_map_item!(i, seen, f, newys, ys, xs, copy_if_inactive, isinactivetype) + recursive_map_item!(i, seen, f, newys, ys, xs, config) end return nothing end end @inline function recursive_map_immutable( - seen, f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, copy_if_inactive, isinactivetype::L -) where {F,Nout,Nin,T,L} + seen, f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, config +) where {F,Nout,Nin,T} @assert !ismutabletype(T) nf = fieldcount(T) if nf == 0 # nothing to do (also no known way to hit this branch) - newys = recursive_map_inactive(seen, ys, xs, Val(false)) + newys = recursive_map_inactive(seen, ys, xs, config) else newys = if isinitialized(first(xs), nf) # fast path when all fields are defined check_allinitialized(Base.tail(xs), nf) - recursive_map_new(seen, f, ys, xs, copy_if_inactive, isinactivetype) + recursive_map_new(seen, f, ys, xs, config) else - recursive_map_immutable_inner(seen, f, ys, xs, copy_if_inactive, isinactivetype) + recursive_map_immutable_inner(seen, f, ys, xs, config) end # maybecache! _should_ be a no-op here; call it anyway for consistency maybecache!(seen, newys, xs) @@ -365,8 +333,8 @@ end end @generated function recursive_map_immutable_inner( - seen, f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, copy_if_inactive, isinactivetype::L -) where {F,Nout,Nin,T,L} + seen, f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, config +) where {F,Nout,Nin,T} nf = fieldcount(T) return quote @inline @@ -375,9 +343,7 @@ end Base.Cartesian.@nexprs $(nf - 1) i -> begin # unrolled loop over struct fields @inbounds if isinitialized(x1, i) check_allinitialized(xtail, i) - newys_i = recursive_map_item( - i, seen, f, ys, xs, copy_if_inactive, isinactivetype - ) + newys_i = recursive_map_item(i, seen, f, ys, xs, config) Base.Cartesian.@nexprs $Nout j -> (fields[j][i] = newys_i[j]) else return new_structvs(T, fields, i - 1) @@ -389,14 +355,14 @@ end end @generated function recursive_map_new( - seen, f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, copy_if_inactive, isinactivetype::L -) where {F,Nout,Nin,T,L} + seen, f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, config +) where {F,Nout,Nin,T} # direct construction of fully initialized non-cyclic structs nf = fieldcount(T) return quote @inline Base.Cartesian.@nexprs $nf i -> @inbounds begin - newys_i = recursive_map_item(i, seen, f, ys, xs, copy_if_inactive, isinactivetype) + newys_i = recursive_map_item(i, seen, f, ys, xs, config) end newys = Base.@ntuple $Nout j -> begin $(Expr(:splatnew, :T, :(Base.@ntuple $nf i -> newys_i[j]))) @@ -406,18 +372,11 @@ end end Base.@propagate_inbounds function recursive_map_item!( - i, - seen, - f::F, - newys::NTuple{Nout,T}, - ys::YS{Nout,T}, - xs::NTuple{Nin,T}, - copy_if_inactive, - isinactivetype::L, -) where {F,Nout,Nin,T,L} + i, seen, f::F, newys::NTuple{Nout,T}, ys::YS{Nout,T}, xs::NTuple{Nin,T}, config +) where {F,Nout,Nin,T} if isinitialized(first(xs), i) check_allinitialized(Base.tail(xs), i) - newys_i = recursive_map_item(i, seen, f, ys, xs, copy_if_inactive, isinactivetype) + newys_i = recursive_map_item(i, seen, f, ys, xs, config) setitems!(newys, i, newys_i) elseif hasvalues(ys) check_allinitialized(ys, i, false) @@ -426,46 +385,38 @@ Base.@propagate_inbounds function recursive_map_item!( end Base.@propagate_inbounds function recursive_map_item( - i, seen, f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, copy_if_inactive, isinactivetype::L -) where {F,Nout,Nin,T,L} + i, seen, f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, config +) where {F,Nout,Nin,T} # recurse into the xs and apply recursive_map to items with index i xs_i = getitems(xs, i) newys_i = if hasvalues(ys) && isinitialized(first(ys), i) check_allinitialized(Base.tail(ys), i) ys_i = getitems(ys, i) - recursive_map_barrier!!(seen, f, ys_i..., copy_if_inactive, isinactivetype, xs_i...) + recursive_map_barrier!!(seen, f, ys_i..., config, xs_i...) else - recursive_map_barrier(seen, f, Val(Nout), copy_if_inactive, isinactivetype, xs_i...) + recursive_map_barrier(seen, f, Val(Nout), config, xs_i...) end return newys_i end # function barriers such that abstractly typed items trigger minimal runtime dispatch function recursive_map_barrier( - seen, f::F, ::Val{Nout}, copy_if_inactive::Val, isinactivetype::L, xs_i::Vararg{ST,Nin} -) where {F,Nout,Nin,ST,L} - return recursive_map( - seen, f, Val(Nout), xs_i, copy_if_inactive, isinactivetype - )::NTuple{Nout,ST} + seen, f::F, ::Val{Nout}, config::InactiveConfig, xs_i::Vararg{ST,Nin} +) where {F,Nout,Nin,ST} + return recursive_map(seen, f, Val(Nout), xs_i, config)::NTuple{Nout,ST} end function recursive_map_barrier!!( - seen, f::F, y_i::ST, copy_if_inactive::Val, isinactivetype::L, xs_i::Vararg{ST,Nin} -) where {F,Nin,ST,L} - return recursive_map(seen, f, (y_i,), xs_i, copy_if_inactive, isinactivetype)::NTuple{1,ST} + seen, f::F, y_i::ST, config::InactiveConfig, xs_i::Vararg{ST,Nin} +) where {F,Nin,ST} + return recursive_map(seen, f, (y_i,), xs_i, config)::NTuple{1,ST} end function recursive_map_barrier!!( # may not show in coverage but is covered via accumulate_into! TODO: ensure coverage via VectorSpace once implemented - seen, - f::F, - y1_i::ST, - y2_i::ST, - copy_if_inactive::Val, - isinactivetype::L, - xs_i::Vararg{ST,Nin} -) where {F,Nin,ST,L} + seen, f::F, y1_i::ST, y2_i::ST, config::InactiveConfig, xs_i::Vararg{ST,Nin} +) where {F,Nin,ST} ys_i = (y1_i, y2_i) - return recursive_map(seen, f, ys_i, xs_i, copy_if_inactive, isinactivetype)::NTuple{2,ST} + return recursive_map(seen, f, ys_i, xs_i, config)::NTuple{2,ST} end ## recursion base case handlers @@ -486,13 +437,13 @@ end end @inline function recursive_map_inactive( - _, ys::NTuple{Nout,T}, xs::NTuple{Nin,T}, ::Val{copy_if_inactive} + _, ys::NTuple{Nout,T}, xs::NTuple{Nin,T}, ::InactiveConfig{copy_if_inactive} ) where {Nout,Nin,T,copy_if_inactive} return ys::NTuple{Nout,T} end @inline function recursive_map_inactive( - seen, ::Val{Nout}, (x1,)::NTuple{Nin,T}, ::Val{copy_if_inactive} + seen, ::Val{Nout}, (x1,)::NTuple{Nin,T}, ::InactiveConfig{copy_if_inactive} ) where {Nout,Nin,T,copy_if_inactive} @inline y = if copy_if_inactive && !isbitstype(T) @@ -514,8 +465,7 @@ end f!!, ys::NTuple{Nout,T}, xs::NTuple{Nin,T}, - ::Val{copy_if_inactive}=Val(false), - isinactivetype::IsInactive=IsInactive{false}(), + isinactivetype::InactiveConfig=InactiveConfig(), )::Nothing !!! warning @@ -529,23 +479,14 @@ in-place with the resulting values. This is a simple wrapper that verifies that `T` is a type where all differentiable values can be updated in-place, calls `recursive_map`, and verifies that the returned value is indeed identically the same tuple `ys`. See [`recursive_map`](@ref) for details. - -Note that this wrapper only supports instances of [`IsInactive`](@ref) for the -`isinactivetype` argument, as this is the only way we can insure consistency between the -upfront compatibility check and actual behavior. If this is not appropriate, use -`recursive_map` directly. """ function recursive_map! end function recursive_map!( - f!!::F, - ys::NTuple{Nout,T}, - xs::NTuple{Nin,T}, - copy_if_inactive::Val=Val(false), - isinactivetype::IsInactive=IsInactive{false}(), + f!!::F, ys::NTuple{Nout,T}, xs::NTuple{Nin,T}, config::InactiveConfig=InactiveConfig() ) where {F,Nout,Nin,T} - check_nonactive(T, isinactivetype) - newys = recursive_map(f!!, ys, xs, copy_if_inactive, isinactivetype) + check_nonactive(T, config) + newys = recursive_map(f!!, ys, xs, config) @assert newys === ys return nothing end @@ -555,11 +496,10 @@ function recursive_map!( f!!::F, ys::NTuple{Nout,T}, xs::NTuple{Nin,T}, - copy_if_inactive::Val=Val(false), - isinactivetype::IsInactive=IsInactive{false}(), + config::InactiveConfig=InactiveConfig(), ) where {F,Nout,Nin,T} - check_nonactive(T, isinactivetype) - newys = recursive_map(seen, f!!, ys, xs, copy_if_inactive, isinactivetype) + check_nonactive(T, config) + newys = recursive_map(seen, f!!, ys, xs, config) @assert newys === ys return nothing end @@ -685,8 +625,8 @@ Base.@propagate_inbounds check_allinitialized(::Tuple{}, i, initialized=true) = return nothing end -@inline function check_nonactive(::Type{T}, isinactivetype::IsInactive) where {T} - if !isinactivetype(T, Val(true)) +@inline function check_nonactive(::Type{T}, config) where {T} + if !isnonactivetype(T, config) throw_nonactive() end return nothing @@ -714,51 +654,51 @@ end end ### EnzymeCore.make_zero(!) implementation -function EnzymeCore.make_zero(prev::T, args::Vararg{Any,M}) where {T,M} - new = if iszero(M) && !IsInactive{false}()(T) && isvectortype(T) # fallback - # IsInactive has precedence over isvectortype for consistency with recursive handler - convert(T, zero(prev)) # convert because zero(prev)::T may fail when eltype(T) is abstract +@inline function EnzymeCore.make_zero(prev::T, args::Vararg{Any,M}; kws...) where {T,M} + config = make_zero_config(args...; kws...) + new = if iszero(M) && isempty(kws) && !isinactivetype(T, config) && isvectortype(T) # fallback + # isinactivetype precedes over isvectortype for consistency with recursive handler + convert(T, zero(prev)) # convert because zero(prev)::T may not hold when eltype(T) is abstract else - _make_zero_inner(prev, args...) + only(recursive_map(_make_zero!!, Val(1), (prev,), config))::T end return new::T end -function EnzymeCore.make_zero!(val::T, args::Vararg{Any,M}) where {T,M} +@inline function EnzymeCore.make_zero!(val::T, args::Vararg{Any,M}; kws...) where {T,M} @assert !isscalartype(T) # not appropriate for in-place handler - if iszero(M) && !IsInactive{false}()(T) && isvectortype(T) # fallback - # IsInactive has precedence over isvectortype for consistency with recursive handler + if iszero(M) && isempty(kws) && !isinactivetype(T, make_zero!_config()) && isvectortype(T) # fallback + # isinactivetype precedes over isvectortype for consistency with recursive handler fill!(val, false) else - _make_zero_inner!(val, args...) + _make_zero_inner!(val, args...; kws...) end return nothing end -@inline function _make_zero_inner( - prev::T, copy_if_inactive::Val=Val(false), ::Val{runtime_inactive}=Val(false) -) where {T,runtime_inactive} - isinactivetype = IsInactive{runtime_inactive}() - news = recursive_map(_make_zero!!, Val(1), (prev,), copy_if_inactive, isinactivetype) - return only(news)::T +@inline function _make_zero_inner!(val, args::Vararg{Any,M}; kws...) where {M} + return recursive_map!(_make_zero!!, (val,), (val,), make_zero!_config(args...; kws...)) end - -@inline function _make_zero_inner!( - val::T, ::Val{runtime_inactive}=Val(false) -) where {T,runtime_inactive} - isinactivetype = IsInactive{runtime_inactive}() - recursive_map!(_make_zero!!, (val,), (val,), Val(false), isinactivetype) - return nothing +@inline function _make_zero_inner!(val, seen::IdDict, args::Vararg{Any,M}; kws...) where {M} + config = make_zero!_config(args...; kws...) + return recursive_map!(seen, _make_zero!!, (val,), (val,), config) end -@inline function _make_zero_inner!( - val::T, seen::IdDict, ::Val{runtime_inactive}=Val(false) -) where {T,runtime_inactive} - isinactivetype = IsInactive{runtime_inactive}() - recursive_map!(seen, _make_zero!!, (val,), (val,), Val(false), isinactivetype) - return nothing +# map make_zero(!) args/kws to config +@inline make_zero_config(C) = InactiveConfig(; copy_if_inactive=C) +@inline make_zero_config(C, R) = InactiveConfig(; copy_if_inactive=C, runtime_inactive=R) +@inline make_zero_config(; kws...) = InactiveConfig(; kws...) + +@inline make_zero!_config(R) = InactiveConfig(; runtime_inactive=R) +@inline function make_zero!_config(; runtime_inactive=nothing) + if isnothing(runtime_inactive) + return InactiveConfig() + else + return InactiveConfig(; runtime_inactive) + end end +# the mapped function: assert leaf type and call back into single-arg make_zero(!) function _make_zero!!(prev::T) where {T} @assert isvectortype(T) # otherwise infinite loop return (EnzymeCore.make_zero(prev)::T,) @@ -773,15 +713,11 @@ function _make_zero!!(val::T, _val::T) where {T} end # alternative entry point for passing custom IdDict -function EnzymeCore.make_zero( - ::Type{T}, - seen::IdDict, - prev::T, - copy_if_inactive::Val=Val(false), - ::Val{runtime_inactive}=Val(false), -) where {T,runtime_inactive} - isinactivetype = IsInactive{runtime_inactive}() - news = recursive_map(seen, _make_zero!!, Val(1), (prev,), copy_if_inactive, isinactivetype) +@inline function EnzymeCore.make_zero( + ::Type{T}, seen::IdDict, prev::T, args::Vararg{Any,M}; kws... +) where {T,M} + config = make_zero_config(args...; kws...) + news = recursive_map(seen, _make_zero!!, Val(1), (prev,), config) return only(news)::T end diff --git a/test/recursive_maps.jl b/test/recursive_maps.jl index 256660b300..787b0f58ec 100644 --- a/test/recursive_maps.jl +++ b/test/recursive_maps.jl @@ -434,14 +434,16 @@ function test_make_zero() end end end - @testset "copy_if_inactive $value" for (value, args) in [ - ("unspecified", ()), - ("= false", (Val(false),)), - ("= true", (Val(true),)), + @testset "copy_if_inactive $value" for (value, args, kwargs) in [ + ("unspecified", (), (;)), + ("= false", (Val(false),), (;)), + ("= false (kwarg)", (), (; copy_if_inactive=Val(false))), + ("= true", (Val(true),), (;)), + ("= true (kwarg)", (), (; copy_if_inactive=Val(true))), ] a = [1.0] w = Any[a, inactivearr, inactivearr] - w_makez = make_zero(w, args...) + w_makez = make_zero(w, args...; kwargs...) @test typeof(w_makez) === typeof(w) # correct type @test typeof(w_makez[1]) === typeof(a) # correct type @test w_makez[1] == [0.0] # correct value @@ -451,7 +453,7 @@ function test_make_zero() @test w[2] === w[3] # no mutation of original @test w[2] === inactivearr # no mutation of original @test inactivearr[1] === inactivetup # no mutation of original - if args == (Val(true),) + if (args == (Val(true),)) || (kwargs == (; copy_if_inactive=Val(true))) @test typeof(w_makez[2]) === typeof(inactivearr) # correct type @test w_makez[2] == inactivearr # correct value @test w_makez[2][1] !== inactivetup # correct identity @@ -550,14 +552,18 @@ function test_make_zero() @test v.data === a # no mutation of original @test a[1] === 1.0 # no mutation of original end - @testset "runtime inactive" begin + @testset "runtime_inactive" begin # verify that MutableWrapper is seen as active by both variants a = MutableWrapper(1.0) @assert !EnzymeRules.inactive_type(typeof(a)) a_makez = make_zero(a, Val(false), Val(false)) @assert a_makez == MutableWrapper(0.0) + a_makez = make_zero(a; runtime_inactive=Val(false)) + @assert a_makez == MutableWrapper(0.0) a_makez = make_zero(a, Val(false), Val(true)) @assert a_makez == MutableWrapper(0.0) + a_makez = make_zero(a; runtime_inactive=Val(true)) + @assert a_makez == MutableWrapper(0.0) # mark MutableWrapper as inactive @eval @inline EnzymeRules.inactive_type(::Type{<:MutableWrapper}) = true @@ -565,14 +571,25 @@ function test_make_zero() # runtime_inactive == false => redefined inactive_type should have no effect a_makez = @invokelatest make_zero(a, Val(false), Val(false)) @test a_makez == MutableWrapper(0.0) + a_makez = @invokelatest make_zero(a; runtime_inactive=Val(false)) + @test a_makez == MutableWrapper(0.0) # runtime_inactive == true => redefined inactive_type should take effect # MutableWrapper considered inactive and treated according to copy_if_inactive a_makez = @invokelatest make_zero(a, Val(false), Val(true)) @test a_makez === a + a_makez = @invokelatest make_zero( + a; copy_if_inactive=Val(false), runtime_inactive=Val(true) + ) + @test a_makez === a a_makez = @invokelatest make_zero(a, Val(true), Val(true)) @test a_makez !== a @test a_makez == MutableWrapper(1.0) + a_makez = @invokelatest make_zero( + a; copy_if_inactive=Val(true), runtime_inactive=Val(true) + ) + @test a_makez !== a + @test a_makez == MutableWrapper(1.0) # mark MutableWrapper as active again @eval @inline EnzymeRules.inactive_type(::Type{<:MutableWrapper}) = false @@ -580,8 +597,12 @@ function test_make_zero() # verify that MutableWrapper is seen as active by both variants a_makez = @invokelatest make_zero(a, Val(false), Val(false)) @test a_makez == MutableWrapper(0.0) + a_makez = @invokelatest make_zero(a; runtime_inactive=Val(false)) + @test a_makez == MutableWrapper(0.0) a_makez = @invokelatest make_zero(a, Val(false), Val(true)) @test a_makez == MutableWrapper(0.0) + a_makez = @invokelatest make_zero(a; runtime_inactive=Val(true)) + @test a_makez == MutableWrapper(0.0) end @testset "undefined fields/unassigned elements" begin @testset "array w inactive/active/mutable/unassigned" begin @@ -843,15 +864,22 @@ function test_make_zero!() @test v.data === a # preserved identity @test a[1] === 0.0 # correct value end - @testset "runtime inactive" begin + @testset "runtime_inactive" begin # verify that MutableWrapper is seen as active by both variants a = MutableWrapper(1.0) @assert !EnzymeRules.inactive_type(typeof(a)) + a.x = 1.0 make_zero!(a, Val(false)) @assert a == MutableWrapper(0.0) a.x = 1.0 + make_zero!(a; runtime_inactive=Val(false)) + @assert a == MutableWrapper(0.0) + a.x = 1.0 make_zero!(a, Val(true)) @assert a == MutableWrapper(0.0) + a.x = 1.0 + make_zero!(a; runtime_inactive=Val(true)) + @assert a == MutableWrapper(0.0) # mark MutableWrapper as inactive @eval @inline EnzymeRules.inactive_type(::Type{<:MutableWrapper}) = true @@ -860,22 +888,34 @@ function test_make_zero!() a.x = 1.0 @invokelatest make_zero!(a, Val(false)) @test a == MutableWrapper(0.0) + a.x = 1.0 + @invokelatest make_zero!(a; runtime_inactive=Val(false)) + @test a == MutableWrapper(0.0) # runtime_inactive == true => redefined inactive_type should take effect # MutableWrapper considered inactive and won't be zeroed a.x = 1.0 @invokelatest make_zero!(a, Val(true)) @test a == MutableWrapper(1.0) + a.x = 1.0 + @invokelatest make_zero!(a; runtime_inactive=Val(true)) + @test a == MutableWrapper(1.0) # mark MutableWrapper as active again @eval @inline EnzymeRules.inactive_type(::Type{<:MutableWrapper}) = false # verify that MutableWrapper is seen as active by both variants a.x = 1.0 + @invokelatest make_zero!(a, Val(true)) + @test a == MutableWrapper(0.0) + a.x = 1.0 + @invokelatest make_zero!(a; runtime_inactive=Val(true)) + @test a == MutableWrapper(0.0) + a.x = 1.0 @invokelatest make_zero!(a, Val(false)) @test a == MutableWrapper(0.0) a.x = 1.0 - @invokelatest make_zero!(a, Val(true)) + @invokelatest make_zero!(a; runtime_inactive=Val(false)) @test a == MutableWrapper(0.0) end @testset "undefined fields/unassigned elements" begin From fac2f0b9bcdbbe36c3be5e30785602aa9c7dcb24 Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Tue, 21 Jan 2025 13:16:00 -0800 Subject: [PATCH 08/17] Use simple assert for Nout check This is an internal sanity check that shouldn't need to throw an error back to the user --- src/typeutils/recursive_maps.jl | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/src/typeutils/recursive_maps.jl b/src/typeutils/recursive_maps.jl index 15f5d9dd3e..fc16319111 100644 --- a/src/typeutils/recursive_maps.jl +++ b/src/typeutils/recursive_maps.jl @@ -206,7 +206,7 @@ const YS{Nout,T} = Union{Val{Nout},NTuple{Nout,T}} function recursive_map( f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, config::InactiveConfig=InactiveConfig() ) where {F,Nout,Nin,T} - check_nout(ys) + @assert (Nout == 1) || (Nout == 2) newys = if isinactivetype(T, config) recursive_map_inactive(nothing, ys, xs, config) elseif isvectortype(T) || isbitstype(T) @@ -226,7 +226,7 @@ function recursive_map( config::InactiveConfig=InactiveConfig(), ) where {F,Nout,Nin,T} # determine whether to continue recursion, copy/share, or retrieve from cache - check_nout(ys) + @assert (Nout == 1) || (Nout == 2) newys = if isinactivetype(T, config) recursive_map_inactive(seen, ys, xs, config) elseif isbitstype(T) # no object identity to to track in this branch @@ -581,19 +581,13 @@ end else # may not show in coverage but is covered via accumulate_into! TODO: ensure coverage via VectorSpace once implemented cache = seen[x1]::NTuple{(Nout + Nin - 1),T} cachedtail = cache[(Nout+1):end] - check_identical(cachedtail, xtail) # check compatible layout + check_identical(cachedtail, xtail) # check compatible structure cache[1:Nout] end return newys::NTuple{Nout,T} end ## argument validation -@inline function check_nout(::YS{Nout}) where {Nout} - if Nout > 2 - throw_nout() - end -end - Base.@propagate_inbounds function check_initialized(x, i, initialized=true) if isinitialized(x, i) != initialized throw_initialized() # TODO: hit this when VectorSpace implemented @@ -644,7 +638,7 @@ end end @noinline function throw_identical() - msg = "recursive_map(!) called on objects whose layout don't match" + msg = "recursive_map(!) called on objects whose structure don't match" throw(ArgumentError(msg)) end From 43d182e15d2407dfd0f094d58ebb62feb2fb2954 Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Tue, 21 Jan 2025 18:16:48 -0800 Subject: [PATCH 09/17] Reorganize test module to please the formatter Runic says that a module should be indented unless it's the only top-level element in a file, so let's make sure it is Using `# !format: off` rather than `# runic: off` around manually aligned code for compatibility with whatever autoformatter may be active active in people's editors --- test/recursive_maps.jl | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/test/recursive_maps.jl b/test/recursive_maps.jl index 787b0f58ec..8fe5700bac 100644 --- a/test/recursive_maps.jl +++ b/test/recursive_maps.jl @@ -191,6 +191,7 @@ const inactivetup = (inactivebits, "a", MutableEmpty()) const inactivearr = [inactivetup] const wrappers = [ + #! format: off (name="Tuple{X}", f=tuple, N=1, mutable=false, typed=true, bitsonly=false), (name="@NamedTuple{x::X}", f=(NamedTuple{(:x,)} ∘ tuple), N=1, mutable=false, typed=true, bitsonly=false), (name="struct{X}", f=Wrapper, N=1, mutable=false, typed=true, bitsonly=false), @@ -238,6 +239,7 @@ const wrappers = [ # GPUArrays extension (name="JLArray{X}", f=(x -> JLArray([x])), N=1, mutable=true, typed=true, bitsonly=true), (name="JLArray{promote_type(X,Y)}", f=((x, y) -> JLArray([x, y])), N=2, mutable=true, typed=:promoted, bitsonly=true), + #! format: on ] @static if VERSION < v"1.11-" @@ -245,10 +247,12 @@ else _memory(x::Vector) = Memory{eltype(x)}(x) push!( wrappers, + #! format: off (name="Memory{X}", f=(x -> _memory([x])), N=1, mutable=true, typed=true, bitsonly=false), (name="Memory{Any}", f=(x -> _memory(Any[x])), N=1, mutable=true, typed=false, bitsonly=false), (name="Memory{promote_type(X,Y)}", f=((x, y) -> _memory([x, y])), N=2, mutable=true, typed=:promoted, bitsonly=false), (name="Memory{Any}", f=((x, y) -> _memory(Any[x, y])), N=2, mutable=true, typed=false, bitsonly=false), + #! format: on ) end @@ -435,11 +439,13 @@ function test_make_zero() end end @testset "copy_if_inactive $value" for (value, args, kwargs) in [ - ("unspecified", (), (;)), + #! format: off + ("unspecified", (), (;)), ("= false", (Val(false),), (;)), - ("= false (kwarg)", (), (; copy_if_inactive=Val(false))), - ("= true", (Val(true),), (;)), - ("= true (kwarg)", (), (; copy_if_inactive=Val(true))), + ("= false (kwarg)", (), (; copy_if_inactive=Val(false))), + ("= true", (Val(true),), (;)), + ("= true (kwarg)", (), (; copy_if_inactive=Val(true))), + #! format: on ] a = [1.0] w = Any[a, inactivearr, inactivearr] @@ -973,7 +979,11 @@ function test_make_zero!() return nothing end -end # module RecursiveMapTests +# because this is wrapped in a module, we should only run a single top-level testset +# otherwise a failed test in the first set will prevent the second from running +@testset "recursive maps" begin + @testset "make_zero" test_make_zero() + @testset "make_zero!" test_make_zero!() +end -@testset "make_zero" RecursiveMapTests.test_make_zero() -@testset "make_zero!" RecursiveMapTests.test_make_zero!() +end # module RecursiveMapTests From 1ad3882f5cb2b21ce772115fb913dd0346bbb8ac Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Tue, 21 Jan 2025 19:04:14 -0800 Subject: [PATCH 10/17] Apply suggestions from Runic.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- ext/EnzymeStaticArraysExt.jl | 2 +- src/typeutils/recursive_add.jl | 6 ++-- src/typeutils/recursive_maps.jl | 52 ++++++++++++++++----------------- test/runtests.jl | 6 ++-- 4 files changed, 33 insertions(+), 33 deletions(-) diff --git a/ext/EnzymeStaticArraysExt.jl b/ext/EnzymeStaticArraysExt.jl index 14d18a2835..8c15f21cb7 100644 --- a/ext/EnzymeStaticArraysExt.jl +++ b/ext/EnzymeStaticArraysExt.jl @@ -35,7 +35,7 @@ end # SArrays and MArrays don't need special treatment for `make_zero(!)` to work or be correct, # but in case their dedicated `zero` and `fill!` methods are more efficient than # `make_zero(!)`s recursion, we opt into treating them as leaves. -@inline function Enzyme.EnzymeCore.isvectortype(::Type{<:StaticArray{S,T}}) where {S,T} +@inline function Enzyme.EnzymeCore.isvectortype(::Type{<:StaticArray{S, T}}) where {S, T} return isbitstype(T) && Enzyme.EnzymeCore.isscalartype(T) end diff --git a/src/typeutils/recursive_add.jl b/src/typeutils/recursive_add.jl index 482b6fc503..368eca3d09 100644 --- a/src/typeutils/recursive_add.jl +++ b/src/typeutils/recursive_add.jl @@ -19,7 +19,7 @@ types, such that `zi = xi + f(yi)` applies to differentiable values, while `zi = to non-differentiable values. If a custom callable is passed, it is combined with the default, as `recursive_add` is not generally capable of traversing inactive objects. """ -function recursive_add(x::T, y::T, f::F=identity, forcelhs::L=guaranteed_const) where {T,F,L} +function recursive_add(x::T, y::T, f::F = identity, forcelhs::L = guaranteed_const) where {T, F, L} function addf(xi::S, yi::S) where {S} @assert EnzymeCore.isvectortype(S) return ((xi + f(yi))::S,) @@ -55,7 +55,7 @@ be passed to `accumulate_seen` to enzure consistency. """ function accumulate_seen! end -function accumulate_seen!(f::F, seen::IdDict, args::Vararg{Any,M}; kws...) where {F,M} +function accumulate_seen!(f::F, seen::IdDict, args::Vararg{Any, M}; kws...) where {F, M} accumulate_seen!(f, seen, RecursiveMaps.make_zero!_config(args...; kws...)) return nothing end @@ -68,7 +68,7 @@ function accumulate_seen!(f::F, seen::IdDict, config::RecursiveMaps.InactiveConf return nothing end -function _accumulate_seen_item!(f::F, k::T, v::T, config, cachedconfig) where {F,T} +function _accumulate_seen_item!(f::F, k::T, v::T, config, cachedconfig) where {F, T} function addf!!(ki::S, vi::S) where {S} @assert EnzymeCore.isvectortype(S) return ((ki .+ f.(vi))::S,) diff --git a/src/typeutils/recursive_maps.jl b/src/typeutils/recursive_maps.jl index fc16319111..d97090869c 100644 --- a/src/typeutils/recursive_maps.jl +++ b/src/typeutils/recursive_maps.jl @@ -88,8 +88,8 @@ end end @inline EnzymeCore.isscalartype(::Type) = false -@inline EnzymeCore.isscalartype(::Type{T}) where {T<:AbstractFloat} = isconcretetype(T) -@inline function EnzymeCore.isscalartype(::Type{Complex{T}}) where {T<:AbstractFloat} +@inline EnzymeCore.isscalartype(::Type{T}) where {T <: AbstractFloat} = isconcretetype(T) +@inline function EnzymeCore.isscalartype(::Type{Complex{T}}) where {T <: AbstractFloat} return isconcretetype(T) end @@ -301,8 +301,8 @@ end end @generated function recursive_map_mutable_inner!( - seen, f::F, newys::NTuple{Nout,T}, ys::YS{Nout,T}, xs::NTuple{Nin,T}, config -) where {F,Nout,Nin,T} + seen, f::F, newys::NTuple{Nout, T}, ys::YS{Nout, T}, xs::NTuple{Nin, T}, config + ) where {F, Nout, Nin, T} return quote @inline Base.Cartesian.@nexprs $(fieldcount(T)) i -> @inbounds begin @@ -313,8 +313,8 @@ end end @inline function recursive_map_immutable( - seen, f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, config -) where {F,Nout,Nin,T} + seen, f::F, ys::YS{Nout, T}, xs::NTuple{Nin, T}, config + ) where {F, Nout, Nin, T} @assert !ismutabletype(T) nf = fieldcount(T) if nf == 0 # nothing to do (also no known way to hit this branch) @@ -355,8 +355,8 @@ end end @generated function recursive_map_new( - seen, f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, config -) where {F,Nout,Nin,T} + seen, f::F, ys::YS{Nout, T}, xs::NTuple{Nin, T}, config + ) where {F, Nout, Nin, T} # direct construction of fully initialized non-cyclic structs nf = fieldcount(T) return quote @@ -385,8 +385,8 @@ Base.@propagate_inbounds function recursive_map_item!( end Base.@propagate_inbounds function recursive_map_item( - i, seen, f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, config -) where {F,Nout,Nin,T} + i, seen, f::F, ys::YS{Nout, T}, xs::NTuple{Nin, T}, config + ) where {F, Nout, Nin, T} # recurse into the xs and apply recursive_map to items with index i xs_i = getitems(xs, i) newys_i = if hasvalues(ys) && isinitialized(first(ys), i) @@ -455,7 +455,7 @@ end else x1 end - return ntuple(_ -> (@inline; y), Val(Nout))::NTuple{Nout,T} + return ntuple(_ -> (@inline; y), Val(Nout))::NTuple{Nout, T} end ### recursive_map!: fully in-place wrapper around recursive_map @@ -483,8 +483,8 @@ indeed identically the same tuple `ys`. See [`recursive_map`](@ref) for details. function recursive_map! end function recursive_map!( - f!!::F, ys::NTuple{Nout,T}, xs::NTuple{Nin,T}, config::InactiveConfig=InactiveConfig() -) where {F,Nout,Nin,T} + f!!::F, ys::NTuple{Nout, T}, xs::NTuple{Nin, T}, config::InactiveConfig = InactiveConfig() + ) where {F, Nout, Nin, T} check_nonactive(T, config) newys = recursive_map(f!!, ys, xs, config) @assert newys === ys @@ -492,12 +492,12 @@ function recursive_map!( end function recursive_map!( - seen::Union{Nothing,IdDict}, - f!!::F, - ys::NTuple{Nout,T}, - xs::NTuple{Nin,T}, - config::InactiveConfig=InactiveConfig(), -) where {F,Nout,Nin,T} + seen::Union{Nothing, IdDict}, + f!!::F, + ys::NTuple{Nout, T}, + xs::NTuple{Nin, T}, + config::InactiveConfig = InactiveConfig(), + ) where {F, Nout, Nin, T} check_nonactive(T, config) newys = recursive_map(seen, f!!, ys, xs, config) @assert newys === ys @@ -505,7 +505,7 @@ function recursive_map!( end ### recursive_map helpers -@generated function new_structvs(::Type{T}, fields::NTuple{N,Vector{Any}}, nfields_) where {T,N} +@generated function new_structvs(::Type{T}, fields::NTuple{N, Vector{Any}}, nfields_) where {T, N} return quote @inline return Base.@ntuple $N j -> begin @@ -515,7 +515,7 @@ end end @inline _similar(::T) where {T} = ccall(:jl_new_struct_uninit, Any, (Any,), T)::T -@inline _similar(x::T) where {T<:DenseArray} = similar(x)::T +@inline _similar(x::T) where {T <: DenseArray} = similar(x)::T Base.@propagate_inbounds isinitialized(x, i) = isdefined(x, i) Base.@propagate_inbounds isinitialized(x::DenseArray, i) = isassigned(x, i) Base.@propagate_inbounds getitem(x, i) = getfield(x, i) @@ -560,7 +560,7 @@ end @inline shouldcache(::IdDict, ::Type{T}) where {T} = iscachedtype(T) @inline shouldcache(::Nothing, ::Type{T}) where {T} = false -@inline function maybecache!(seen, newys::NTuple{Nout,T}, (x1, xtail...)::NTuple{Nin,T}) where {Nout,Nin,T} +@inline function maybecache!(seen, newys::NTuple{Nout, T}, (x1, xtail...)::NTuple{Nin, T}) where {Nout, Nin, T} if shouldcache(seen, T) seen[x1] = if (Nout == 1) && (Nin == 1) only(newys) @@ -648,7 +648,7 @@ end end ### EnzymeCore.make_zero(!) implementation -@inline function EnzymeCore.make_zero(prev::T, args::Vararg{Any,M}; kws...) where {T,M} +@inline function EnzymeCore.make_zero(prev::T, args::Vararg{Any, M}; kws...) where {T, M} config = make_zero_config(args...; kws...) new = if iszero(M) && isempty(kws) && !isinactivetype(T, config) && isvectortype(T) # fallback # isinactivetype precedes over isvectortype for consistency with recursive handler @@ -659,7 +659,7 @@ end return new::T end -@inline function EnzymeCore.make_zero!(val::T, args::Vararg{Any,M}; kws...) where {T,M} +@inline function EnzymeCore.make_zero!(val::T, args::Vararg{Any, M}; kws...) where {T, M} @assert !isscalartype(T) # not appropriate for in-place handler if iszero(M) && isempty(kws) && !isinactivetype(T, make_zero!_config()) && isvectortype(T) # fallback # isinactivetype precedes over isvectortype for consistency with recursive handler @@ -708,8 +708,8 @@ end # alternative entry point for passing custom IdDict @inline function EnzymeCore.make_zero( - ::Type{T}, seen::IdDict, prev::T, args::Vararg{Any,M}; kws... -) where {T,M} + ::Type{T}, seen::IdDict, prev::T, args::Vararg{Any, M}; kws... + ) where {T, M} config = make_zero_config(args...; kws...) news = recursive_map(seen, _make_zero!!, Val(1), (prev,), config) return only(news)::T diff --git a/test/runtests.jl b/test/runtests.jl index dc8e583788..4b55ab5882 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -553,8 +553,8 @@ end @test_throws MethodError autodiff(ReverseHolomorphic, mul3, Active{Complex}, Active(z)) function reverse_holomorphic_array_tests( - f, val, dval_expected; val_expected=val, ret=Active, mapf=true - ) + f, val, dval_expected; val_expected = val, ret = Active, mapf = true + ) vals = ComplexF64[val] dvals = ComplexF64[zero(val)] autodiff(ReverseHolomorphic, f, ret, Duplicated(vals, dvals)) @@ -592,7 +592,7 @@ end nothing end @testset "setinact" reverse_holomorphic_array_tests( - setinact, 3.4 + 2.7im, 0.0; val_expected=2(3.4 + 2.7im), ret=Const, mapf=false + setinact, 3.4 + 2.7im, 0.0; val_expected = 2(3.4 + 2.7im), ret = Const, mapf = false ) function setinact2(z) From 316cc3124de72e0900e9359adf0938234b02057b Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Tue, 21 Jan 2025 19:15:57 -0800 Subject: [PATCH 11/17] More obliging the formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/recursive_maps.jl | 82 +++++++++++++++++++++--------------------- 1 file changed, 41 insertions(+), 41 deletions(-) diff --git a/test/recursive_maps.jl b/test/recursive_maps.jl index 8fe5700bac..36c16d4fc3 100644 --- a/test/recursive_maps.jl +++ b/test/recursive_maps.jl @@ -20,7 +20,7 @@ setx!(w, x) = (w[begin] = x) sety!(w, y) = (w[end] = y) # non-isbits MArray doesn't support setindex!, so requires a little hack -function setx!(w::MArray{S,T}, x) where {S,T} +function setx!(w::MArray{S, T}, x) where {S, T} if isbitstype(T) w[begin] = x else @@ -29,7 +29,7 @@ function setx!(w::MArray{S,T}, x) where {S,T} return x end -function sety!(w::MArray{S,T}, y) where {S,T} +function sety!(w::MArray{S, T}, y) where {S, T} if isbitstype(T) w[end] = y else @@ -153,7 +153,7 @@ function Enzyme.EnzymeCore.isvectortype(::Type{CustomVector{T}}) where {T} return Enzyme.EnzymeCore.isscalartype(T) end -function Enzyme.EnzymeCore.make_zero(prev::CV) where {CV<:CustomVector{<:AbstractFloat}} +function Enzyme.EnzymeCore.make_zero(prev::CV) where {CV <: CustomVector{<:AbstractFloat}} @info "make_zero(::CustomVector)" return CustomVector(zero(prev.data))::CV end @@ -268,8 +268,8 @@ function test_make_zero() end @testset "nested types" begin @testset "$T in $(wrapper.name)" for T in scalartypes, wrapper in filter( - w -> (w.N == 1), wrappers - ) + w -> (w.N == 1), wrappers + ) (!wrapper.bitsonly || isbitstype(T)) || continue x = oneunit(T) w = wrapper.f(x) @@ -280,8 +280,8 @@ function test_make_zero() @test getx(w) === x # no mutation of original @test x == oneunit(T) # no mutation of original (relevant for BigFloat) @testset "doubly included in $(dualwrapper.name)" for dualwrapper in filter( - w -> (w.N == 2), wrappers - ) + w -> (w.N == 2), wrappers + ) (!dualwrapper.bitsonly || isbitstype(T)) || continue w_inner = wrapper.f(x) if !dualwrapper.bitsonly || isbits(w_inner) @@ -315,8 +315,8 @@ function test_make_zero() # some code paths can only be hit with three layers of wrapping: # mutable(immutable(mutable(scalar))) @testset "all wrapped in $(outerwrapper.name)" for outerwrapper in filter( - w -> ((w.N == 1) && w.mutable && !w.bitsonly), wrappers - ) + w -> ((w.N == 1) && w.mutable && !w.bitsonly), wrappers + ) w_inner = wrapper.f(x) d_middle = dualwrapper.f(w_inner, w_inner) w_outer = outerwrapper.f(d_middle) @@ -341,9 +341,9 @@ function test_make_zero() @testset "in $(wrapper.name)" for wrapper in wrappers if wrapper.N == 1 for (inactive, condition) in [ - (inactivebits, true), - (inactivearr, !wrapper.bitsonly), - ] + (inactivebits, true), + (inactivearr, !wrapper.bitsonly), + ] condition || continue w = wrapper.f(inactive) w_makez = make_zero(w) @@ -363,9 +363,9 @@ function test_make_zero() end @testset "mixed" begin for (inactive, mixed, condition) in [ - (inactivebits, (1.0, inactivebits), true), - (inactivearr, [1.0, inactivearr], !wrapper.bitsonly), - ] + (inactivebits, (1.0, inactivebits), true), + (inactivearr, [1.0, inactivearr], !wrapper.bitsonly), + ] condition || continue w = wrapper.f(mixed) w_makez = make_zero(w) @@ -384,9 +384,9 @@ function test_make_zero() else # wrapper.N == 2 @testset "multiple references" begin for (inactive, condition) in [ - (inactivebits, true), - (inactivearr, !wrapper.bitsonly), - ] + (inactivebits, true), + (inactivearr, !wrapper.bitsonly), + ] condition || continue w = wrapper.f(inactive, inactive) w_makez = make_zero(w) @@ -459,7 +459,7 @@ function test_make_zero() @test w[2] === w[3] # no mutation of original @test w[2] === inactivearr # no mutation of original @test inactivearr[1] === inactivetup # no mutation of original - if (args == (Val(true),)) || (kwargs == (; copy_if_inactive=Val(true))) + if (args == (Val(true),)) || (kwargs == (; copy_if_inactive = Val(true))) @test typeof(w_makez[2]) === typeof(inactivearr) # correct type @test w_makez[2] == inactivearr # correct value @test w_makez[2][1] !== inactivetup # correct identity @@ -508,8 +508,8 @@ function test_make_zero() end @testset "circular references" begin @testset "$(wrapper.name)" for wrapper in filter( - w -> (w.mutable && (w.typed in (:partial, false))), wrappers - ) + w -> (w.mutable && (w.typed in (:partial, false))), wrappers + ) a = [1.0] if wrapper.N == 1 w = wrapper.f(nothing) @@ -577,7 +577,7 @@ function test_make_zero() # runtime_inactive == false => redefined inactive_type should have no effect a_makez = @invokelatest make_zero(a, Val(false), Val(false)) @test a_makez == MutableWrapper(0.0) - a_makez = @invokelatest make_zero(a; runtime_inactive=Val(false)) + a_makez = @invokelatest make_zero(a; runtime_inactive = Val(false)) @test a_makez == MutableWrapper(0.0) # runtime_inactive == true => redefined inactive_type should take effect @@ -648,7 +648,7 @@ function test_make_zero() end @testset "mutable struct w inactive/const active/active/mutable/undefined" begin a = [1.0] - incomplete = MutableIncomplete("a", #=const=#1.0, 1.0, a) + incomplete = MutableIncomplete("a", #=const=# 1.0, 1.0, a) incomplete_makez = make_zero(incomplete) @test typeof(incomplete_makez) === typeof(incomplete) # correct type @test typeof(incomplete_makez.w) === typeof(a) # correct type @@ -670,8 +670,8 @@ end function test_make_zero!() @testset "nested types" begin @testset "$T in $(wrapper.name)" for T in scalartypes, wrapper in filter( - w -> (w.N == 1), wrappers - ) + w -> (w.N == 1), wrappers + ) (!wrapper.bitsonly || isbitstype(T)) || continue x = oneunit(T) if wrapper.mutable @@ -682,8 +682,8 @@ function test_make_zero!() @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) end @testset "doubly included in $(dualwrapper.name)" for dualwrapper in ( - filter(w -> ((w.N == 2) && (w.mutable || wrapper.mutable)), wrappers) - ) + filter(w -> ((w.N == 2) && (w.mutable || wrapper.mutable)), wrappers) + ) (!dualwrapper.bitsonly || isbitstype(T)) || continue w_inner = wrapper.f(x) if !dualwrapper.bitsonly || isbits(w_inner) @@ -715,8 +715,8 @@ function test_make_zero!() # some code paths can only be hit with three layers of wrapping: # mutable(immutable(mutable(scalar))) @testset "all wrapped in $(outerwrapper.name)" for outerwrapper in filter( - w -> ((w.N == 1) && w.mutable && !w.bitsonly), wrappers - ) + w -> ((w.N == 1) && w.mutable && !w.bitsonly), wrappers + ) w_inner = wrapper.f(x) d_middle = dualwrapper.f(w_inner, w_inner) w_outer = outerwrapper.f(d_middle) @@ -752,9 +752,9 @@ function test_make_zero!() end @testset "mixed" begin for (inactive, mixed, condition) in [ - (inactivebits, (1.0, inactivebits), wrapper.mutable), - (inactivearr, [1.0, inactivearr], !wrapper.bitsonly), - ] + (inactivebits, (1.0, inactivebits), wrapper.mutable), + (inactivearr, [1.0, inactivearr], !wrapper.bitsonly), + ] condition || continue w = wrapper.f(mixed) make_zero!(w) @@ -769,9 +769,9 @@ function test_make_zero!() else # wrapper.N == 2 @testset "multiple references" begin for (inactive, condition) in [ - (inactivebits, true), - (inactivearr, !wrapper.bitsonly), - ] + (inactivebits, true), + (inactivearr, !wrapper.bitsonly), + ] condition || continue w = wrapper.f(inactive, inactive) make_zero!(w) @@ -832,8 +832,8 @@ function test_make_zero!() end @testset "circular references" begin @testset "$(wrapper.name)" for wrapper in filter( - w -> (w.mutable && (w.typed in (:partial, false))), wrappers - ) + w -> (w.mutable && (w.typed in (:partial, false))), wrappers + ) a = [1.0] if wrapper.N == 1 w = wrapper.f(nothing) @@ -895,7 +895,7 @@ function test_make_zero!() @invokelatest make_zero!(a, Val(false)) @test a == MutableWrapper(0.0) a.x = 1.0 - @invokelatest make_zero!(a; runtime_inactive=Val(false)) + @invokelatest make_zero!(a; runtime_inactive = Val(false)) @test a == MutableWrapper(0.0) # runtime_inactive == true => redefined inactive_type should take effect @@ -904,7 +904,7 @@ function test_make_zero!() @invokelatest make_zero!(a, Val(true)) @test a == MutableWrapper(1.0) a.x = 1.0 - @invokelatest make_zero!(a; runtime_inactive=Val(true)) + @invokelatest make_zero!(a; runtime_inactive = Val(true)) @test a == MutableWrapper(1.0) # mark MutableWrapper as active again @@ -947,7 +947,7 @@ function test_make_zero!() end @testset "mutable struct w inactive/const active/active/mutable/undefined" begin a = [1.0] - incomplete = MutableIncomplete("a", #=const=#1.0, 1.0, a) + incomplete = MutableIncomplete("a", #=const=# 1.0, 1.0, a) make_zero!(incomplete) @test incomplete == MutableIncomplete("a", 0.0, 0.0, [0.0]) # correct value, preserved undefined @test incomplete.w === a # preserved identity @@ -981,7 +981,7 @@ end # because this is wrapped in a module, we should only run a single top-level testset # otherwise a failed test in the first set will prevent the second from running -@testset "recursive maps" begin +@testset "recursive maps" begin @testset "make_zero" test_make_zero() @testset "make_zero!" test_make_zero!() end From 5f934cbe6c9f27f06c37206cb62db201f7f2163a Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Tue, 21 Jan 2025 19:20:46 -0800 Subject: [PATCH 12/17] Tweak alignment and format:off/on --- test/recursive_maps.jl | 120 +++++++++++++++++++++-------------------- 1 file changed, 62 insertions(+), 58 deletions(-) diff --git a/test/recursive_maps.jl b/test/recursive_maps.jl index 36c16d4fc3..5bfa98029e 100644 --- a/test/recursive_maps.jl +++ b/test/recursive_maps.jl @@ -190,70 +190,70 @@ const inactivebits = (1, Empty()) const inactivetup = (inactivebits, "a", MutableEmpty()) const inactivearr = [inactivetup] +#! format: off const wrappers = [ - #! format: off - (name="Tuple{X}", f=tuple, N=1, mutable=false, typed=true, bitsonly=false), - (name="@NamedTuple{x::X}", f=(NamedTuple{(:x,)} ∘ tuple), N=1, mutable=false, typed=true, bitsonly=false), - (name="struct{X}", f=Wrapper, N=1, mutable=false, typed=true, bitsonly=false), + (name = "Tuple{X}", f = tuple, N = 1, mutable = false, typed = true, bitsonly = false), + (name = "@NamedTuple{x::X}", f = (NamedTuple{(:x,)} ∘ tuple), N = 1, mutable = false, typed = true, bitsonly = false), + (name = "struct{X}", f = Wrapper, N = 1, mutable = false, typed = true, bitsonly = false), - (name="@NamedTuple{x}", f=(@NamedTuple{x} ∘ tuple), N=1, mutable=false, typed=false, bitsonly=false), - (name="struct{Any}", f=Wrapper{Any}, N=1, mutable=false, typed=false, bitsonly=false), + (name = "@NamedTuple{x}", f = (@NamedTuple{x} ∘ tuple), N = 1, mutable = false, typed = false, bitsonly = false), + (name = "struct{Any}", f = Wrapper{Any}, N = 1, mutable = false, typed = false, bitsonly = false), - (name="Array{X}", f=(x -> [x]), N=1, mutable=true, typed=true, bitsonly=false), - (name="Base.RefValue{X}", f=Ref, N=1, mutable=true, typed=true, bitsonly=false), - (name="mutable struct{X}", f=MutableWrapper, N=1, mutable=true, typed=true, bitsonly=false), + (name = "Array{X}", f = (x -> [x]), N = 1, mutable = true, typed = true, bitsonly = false), + (name = "Base.RefValue{X}", f = Ref, N = 1, mutable = true, typed = true, bitsonly = false), + (name = "mutable struct{X}", f = MutableWrapper, N = 1, mutable = true, typed = true, bitsonly = false), - (name="Array{Any}", f=(x -> Any[x]), N=1, mutable=true, typed=false, bitsonly=false), - (name="Base.RefValue{Any}", f=Ref{Any}, N=1, mutable=true, typed=false, bitsonly=false), - (name="Core.Box", f=Core.Box, N=1, mutable=true, typed=false, bitsonly=false), - (name="mutable struct{Any}", f=MutableWrapper{Any}, N=1, mutable=true, typed=false, bitsonly=false), + (name = "Array{Any}", f = (x -> Any[x]), N = 1, mutable = true, typed = false, bitsonly = false), + (name = "Base.RefValue{Any}", f = Ref{Any}, N = 1, mutable = true, typed = false, bitsonly = false), + (name = "Core.Box", f = Core.Box, N = 1, mutable = true, typed = false, bitsonly = false), + (name = "mutable struct{Any}", f = MutableWrapper{Any}, N = 1, mutable = true, typed = false, bitsonly = false), - (name="Tuple{X,Y}", f=tuple, N=2, mutable=false, typed=true, bitsonly=false), - (name="@NamedTuple{x::X,y::Y}", f=(NamedTuple{(:x, :y)} ∘ tuple), N=2, mutable=false, typed=true, bitsonly=false), - (name="struct{X,Y}", f=DualWrapper, N=2, mutable=false, typed=true, bitsonly=false), + (name = "Tuple{X,Y}", f = tuple, N = 2, mutable = false, typed = true, bitsonly = false), + (name = "@NamedTuple{x::X,y::Y}", f = (NamedTuple{(:x, :y)} ∘ tuple), N = 2, mutable = false, typed = true, bitsonly = false), + (name = "struct{X,Y}", f = DualWrapper, N = 2, mutable = false, typed = true, bitsonly = false), - (name="@NamedTuple{x,y::Y}", f=((x, y) -> @NamedTuple{x,y::typeof(y)}((x, y))), N=2, mutable=false, typed=:partial, bitsonly=false), - (name="struct{Any,Y}", f=DualWrapper{Any}, N=2, mutable=false, typed=:partial, bitsonly=false), + (name = "@NamedTuple{x,y::Y}", f = ((x, y) -> @NamedTuple{x,y::typeof(y)}((x, y))), N = 2, mutable = false, typed = :partial, bitsonly = false), + (name = "struct{Any,Y}", f = DualWrapper{Any}, N = 2, mutable = false, typed = :partial, bitsonly = false), - (name="@NamedTuple{x,y}", f=(@NamedTuple{x,y} ∘ tuple), N=2, mutable=false, typed=false, bitsonly=false), - (name="struct{Any}", f=DualWrapper{Any,Any}, N=2, mutable=false, typed=false, bitsonly=false), + (name = "@NamedTuple{x,y}", f = (@NamedTuple{x,y} ∘ tuple), N = 2, mutable = false, typed = false, bitsonly = false), + (name = "struct{Any}", f = DualWrapper{Any,Any}, N = 2, mutable = false, typed = false, bitsonly = false), - (name="mutable struct{X,Y}", f=MutableDualWrapper, N=2, mutable=true, typed=true, bitsonly=false), + (name = "mutable struct{X,Y}", f = MutableDualWrapper, N = 2, mutable = true, typed = true, bitsonly = false), - (name="Array{promote_type(X,Y)}", f=((x, y) -> [x, y]), N=2, mutable=true, typed=:promoted, bitsonly=false), - (name="mutable struct{Any,Y}", f=MutableDualWrapper{Any}, N=2, mutable=true, typed=:partial, bitsonly=false), + (name = "Array{promote_type(X,Y)}", f = ((x, y) -> [x, y]), N = 2, mutable = true, typed = :promoted, bitsonly = false), + (name = "mutable struct{Any,Y}", f = MutableDualWrapper{Any}, N = 2, mutable = true, typed = :partial, bitsonly = false), - (name="Array{Any}", f=((x, y) -> Any[x, y]), N=2, mutable=true, typed=false, bitsonly=false), - (name="mutable struct{Any,Any}", f=MutableDualWrapper{Any,Any}, N=2, mutable=true, typed=false, bitsonly=false), + (name = "Array{Any}", f = ((x, y) -> Any[x, y]), N = 2, mutable = true, typed = false, bitsonly = false), + (name = "mutable struct{Any,Any}", f = MutableDualWrapper{Any,Any}, N = 2, mutable = true, typed = false, bitsonly = false), # StaticArrays extension - (name="SVector{1,X}", f=(SVector{1} ∘ tuple), N=1, mutable=false, typed=true, bitsonly=false), - (name="SVector{1,Any}", f=(SVector{1,Any} ∘ tuple), N=1, mutable=false, typed=false, bitsonly=false), - (name="MVector{1,X}", f=(MVector{1} ∘ tuple), N=1, mutable=true, typed=true, bitsonly=false), - (name="MVector{1,Any}", f=(MVector{1,Any} ∘ tuple), N=1, mutable=true, typed=false, bitsonly=false), - (name="SVector{2,promote_type(X,Y)}", f=(SVector{2} ∘ tuple), N=2, mutable=false, typed=:promoted, bitsonly=false), - (name="SVector{2,Any}", f=(SVector{2,Any} ∘ tuple), N=2, mutable=false, typed=false, bitsonly=false), - (name="MVector{2,promote_type(X,Y)}", f=(MVector{2} ∘ tuple), N=2, mutable=true, typed=:promoted, bitsonly=false), - (name="MVector{2,Any}", f=(MVector{2,Any} ∘ tuple), N=2, mutable=true, typed=false, bitsonly=false), + (name = "SVector{1,X}", f = (SVector{1} ∘ tuple), N = 1, mutable = false, typed = true, bitsonly = false), + (name = "SVector{1,Any}", f = (SVector{1,Any} ∘ tuple), N = 1, mutable = false, typed = false, bitsonly = false), + (name = "MVector{1,X}", f = (MVector{1} ∘ tuple), N = 1, mutable = true, typed = true, bitsonly = false), + (name = "MVector{1,Any}", f = (MVector{1,Any} ∘ tuple), N = 1, mutable = true, typed = false, bitsonly = false), + (name = "SVector{2,promote_type(X,Y)}", f = (SVector{2} ∘ tuple), N = 2, mutable = false, typed = :promoted, bitsonly = false), + (name = "SVector{2,Any}", f = (SVector{2,Any} ∘ tuple), N = 2, mutable = false, typed = false, bitsonly = false), + (name = "MVector{2,promote_type(X,Y)}", f = (MVector{2} ∘ tuple), N = 2, mutable = true, typed = :promoted, bitsonly = false), + (name = "MVector{2,Any}", f = (MVector{2,Any} ∘ tuple), N = 2, mutable = true, typed = false, bitsonly = false), # GPUArrays extension - (name="JLArray{X}", f=(x -> JLArray([x])), N=1, mutable=true, typed=true, bitsonly=true), - (name="JLArray{promote_type(X,Y)}", f=((x, y) -> JLArray([x, y])), N=2, mutable=true, typed=:promoted, bitsonly=true), - #! format: on + (name = "JLArray{X}", f = (x -> JLArray([x])), N = 1, mutable = true, typed = true, bitsonly = true), + (name = "JLArray{promote_type(X,Y)}", f = ((x, y) -> JLArray([x, y])), N = 2, mutable = true, typed = :promoted, bitsonly = true), ] +#! format: on @static if VERSION < v"1.11-" else _memory(x::Vector) = Memory{eltype(x)}(x) +#! format: off push!( wrappers, - #! format: off - (name="Memory{X}", f=(x -> _memory([x])), N=1, mutable=true, typed=true, bitsonly=false), - (name="Memory{Any}", f=(x -> _memory(Any[x])), N=1, mutable=true, typed=false, bitsonly=false), - (name="Memory{promote_type(X,Y)}", f=((x, y) -> _memory([x, y])), N=2, mutable=true, typed=:promoted, bitsonly=false), - (name="Memory{Any}", f=((x, y) -> _memory(Any[x, y])), N=2, mutable=true, typed=false, bitsonly=false), - #! format: on + (name = "Memory{X}", f = (x -> _memory([x])), N = 1, mutable = true, typed = true, bitsonly = false), + (name = "Memory{Any}", f = (x -> _memory(Any[x])), N = 1, mutable = true, typed = false, bitsonly = false), + (name = "Memory{promote_type(X,Y)}", f = ((x, y) -> _memory([x, y])), N = 2, mutable = true, typed = :promoted, bitsonly = false), + (name = "Memory{Any}", f = ((x, y) -> _memory(Any[x, y])), N = 2, mutable = true, typed = false, bitsonly = false), ) +#! format: on end function test_make_zero() @@ -438,15 +438,14 @@ function test_make_zero() end end end + #! format: off @testset "copy_if_inactive $value" for (value, args, kwargs) in [ - #! format: off - ("unspecified", (), (;)), - ("= false", (Val(false),), (;)), - ("= false (kwarg)", (), (; copy_if_inactive=Val(false))), - ("= true", (Val(true),), (;)), - ("= true (kwarg)", (), (; copy_if_inactive=Val(true))), - #! format: on - ] + ("unspecified", (), (;)), + ("= false", (Val(false),), (;)), + ("= false (kwarg)", (), (; copy_if_inactive = Val(false))), + ("= true", (Val(true),), (;)), + ("= true (kwarg)", (), (; copy_if_inactive = Val(true))), + ] a = [1.0] w = Any[a, inactivearr, inactivearr] w_makez = make_zero(w, args...; kwargs...) @@ -467,6 +466,7 @@ function test_make_zero() @test w_makez[2] === inactivearr # correct value/type/identity end end + #! format: on end @testset "heterogeneous containers" begin scalars, scalarsz = oneunit.(scalartypes), zero.(scalartypes) @@ -475,11 +475,12 @@ function test_make_zero() items = (inactivetup..., scalars..., wraps..., mwraps...) itemsz = (inactivetup..., scalarsz..., wrapsz..., mwrapsz...) labels = Symbol.("i" .* string.(1:length(items))) + #! format: off @testset "$name" for (name, c, cz) in [ - ("Tuple", Tuple(items), Tuple(itemsz)), - ("NamedTuple", NamedTuple(labels .=> items), NamedTuple(labels .=> itemsz)), - ("Array", collect(items), collect(itemsz)), - ] + ("Tuple", Tuple(items), Tuple(itemsz)), + ("NamedTuple", NamedTuple(labels .=> items), NamedTuple(labels .=> itemsz)), + ("Array", collect(items), collect(itemsz)), + ] c_makez = make_zero(c) @test typeof(c_makez) === typeof(c) # correct type @test all(typeof(czj) === typeof(cj) for (czj, cj) in zip(c_makez, c)) # correct type @@ -488,6 +489,7 @@ function test_make_zero() @test all(cj === itj for (cj, itj) in zip(c, items)) # no mutation of original @test all(m.x == oneunit(m.x) for m in mwraps) # no mutation of original end + #! format: on end @testset "heterogeneous float arrays" begin b1r, b2r = big"1.0", big"2.0" @@ -802,15 +804,17 @@ function test_make_zero!() items = (inactivetup..., mwraps...) itemsz = (inactivetup..., mwrapsz...) labels = Symbol.("i" .* string.(1:length(items))) + #! format: off @testset "$name" for (name, c, cz) in [ - ("Tuple", Tuple(items), Tuple(itemsz)), - ("NamedTuple", NamedTuple(labels .=> items), NamedTuple(labels .=> itemsz)), - ("Array", collect(items), collect(itemsz)), - ] + ("Tuple", Tuple(items), Tuple(itemsz)), + ("NamedTuple", NamedTuple(labels .=> items), NamedTuple(labels .=> itemsz)), + ("Array", collect(items), collect(itemsz)), + ] make_zero!(c) @test all(cj === itj for (cj, itj) in zip(c, items)) # preserved identities @test c == cz # correct value end + #! format: on end @testset "heterogeneous float arrays" begin b1r, b2r = big"1.0", big"2.0" From bdb98364bdc07fb0ac6377b43e16a5c6958f35a8 Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Tue, 21 Jan 2025 19:24:56 -0800 Subject: [PATCH 13/17] Follow non-alignment related rules in format: off blocks --- test/recursive_maps.jl | 76 +++++++++++++++++++++--------------------- 1 file changed, 38 insertions(+), 38 deletions(-) diff --git a/test/recursive_maps.jl b/test/recursive_maps.jl index 5bfa98029e..1974076160 100644 --- a/test/recursive_maps.jl +++ b/test/recursive_maps.jl @@ -192,53 +192,53 @@ const inactivearr = [inactivetup] #! format: off const wrappers = [ - (name = "Tuple{X}", f = tuple, N = 1, mutable = false, typed = true, bitsonly = false), - (name = "@NamedTuple{x::X}", f = (NamedTuple{(:x,)} ∘ tuple), N = 1, mutable = false, typed = true, bitsonly = false), - (name = "struct{X}", f = Wrapper, N = 1, mutable = false, typed = true, bitsonly = false), + (name = "Tuple{X}", f = tuple, N = 1, mutable = false, typed = true, bitsonly = false), + (name = "@NamedTuple{x::X}", f = (NamedTuple{(:x,)} ∘ tuple), N = 1, mutable = false, typed = true, bitsonly = false), + (name = "struct{X}", f = Wrapper, N = 1, mutable = false, typed = true, bitsonly = false), - (name = "@NamedTuple{x}", f = (@NamedTuple{x} ∘ tuple), N = 1, mutable = false, typed = false, bitsonly = false), - (name = "struct{Any}", f = Wrapper{Any}, N = 1, mutable = false, typed = false, bitsonly = false), + (name = "@NamedTuple{x}", f = (@NamedTuple{x} ∘ tuple), N = 1, mutable = false, typed = false, bitsonly = false), + (name = "struct{Any}", f = Wrapper{Any}, N = 1, mutable = false, typed = false, bitsonly = false), - (name = "Array{X}", f = (x -> [x]), N = 1, mutable = true, typed = true, bitsonly = false), - (name = "Base.RefValue{X}", f = Ref, N = 1, mutable = true, typed = true, bitsonly = false), - (name = "mutable struct{X}", f = MutableWrapper, N = 1, mutable = true, typed = true, bitsonly = false), + (name = "Array{X}", f = (x -> [x]), N = 1, mutable = true, typed = true, bitsonly = false), + (name = "Base.RefValue{X}", f = Ref, N = 1, mutable = true, typed = true, bitsonly = false), + (name = "mutable struct{X}", f = MutableWrapper, N = 1, mutable = true, typed = true, bitsonly = false), - (name = "Array{Any}", f = (x -> Any[x]), N = 1, mutable = true, typed = false, bitsonly = false), - (name = "Base.RefValue{Any}", f = Ref{Any}, N = 1, mutable = true, typed = false, bitsonly = false), - (name = "Core.Box", f = Core.Box, N = 1, mutable = true, typed = false, bitsonly = false), - (name = "mutable struct{Any}", f = MutableWrapper{Any}, N = 1, mutable = true, typed = false, bitsonly = false), + (name = "Array{Any}", f = (x -> Any[x]), N = 1, mutable = true, typed = false, bitsonly = false), + (name = "Base.RefValue{Any}", f = Ref{Any}, N = 1, mutable = true, typed = false, bitsonly = false), + (name = "Core.Box", f = Core.Box, N = 1, mutable = true, typed = false, bitsonly = false), + (name = "mutable struct{Any}", f = MutableWrapper{Any}, N = 1, mutable = true, typed = false, bitsonly = false), - (name = "Tuple{X,Y}", f = tuple, N = 2, mutable = false, typed = true, bitsonly = false), - (name = "@NamedTuple{x::X,y::Y}", f = (NamedTuple{(:x, :y)} ∘ tuple), N = 2, mutable = false, typed = true, bitsonly = false), - (name = "struct{X,Y}", f = DualWrapper, N = 2, mutable = false, typed = true, bitsonly = false), + (name = "Tuple{X, Y}", f = tuple, N = 2, mutable = false, typed = true, bitsonly = false), + (name = "@NamedTuple{x::X, y::Y}", f = (NamedTuple{(:x, :y)} ∘ tuple), N = 2, mutable = false, typed = true, bitsonly = false), + (name = "struct{X, Y}", f = DualWrapper, N = 2, mutable = false, typed = true, bitsonly = false), - (name = "@NamedTuple{x,y::Y}", f = ((x, y) -> @NamedTuple{x,y::typeof(y)}((x, y))), N = 2, mutable = false, typed = :partial, bitsonly = false), - (name = "struct{Any,Y}", f = DualWrapper{Any}, N = 2, mutable = false, typed = :partial, bitsonly = false), + (name = "@NamedTuple{x, y::Y}", f = ((x, y) -> @NamedTuple{x, y::typeof(y)}((x, y))), N = 2, mutable = false, typed = :partial, bitsonly = false), + (name = "struct{Any, Y}", f = DualWrapper{Any}, N = 2, mutable = false, typed = :partial, bitsonly = false), - (name = "@NamedTuple{x,y}", f = (@NamedTuple{x,y} ∘ tuple), N = 2, mutable = false, typed = false, bitsonly = false), - (name = "struct{Any}", f = DualWrapper{Any,Any}, N = 2, mutable = false, typed = false, bitsonly = false), + (name = "@NamedTuple{x, y}", f = (@NamedTuple{x, y} ∘ tuple), N = 2, mutable = false, typed = false, bitsonly = false), + (name = "struct{Any}", f = DualWrapper{Any, Any}, N = 2, mutable = false, typed = false, bitsonly = false), - (name = "mutable struct{X,Y}", f = MutableDualWrapper, N = 2, mutable = true, typed = true, bitsonly = false), + (name = "mutable struct{X, Y}", f = MutableDualWrapper, N = 2, mutable = true, typed = true, bitsonly = false), - (name = "Array{promote_type(X,Y)}", f = ((x, y) -> [x, y]), N = 2, mutable = true, typed = :promoted, bitsonly = false), - (name = "mutable struct{Any,Y}", f = MutableDualWrapper{Any}, N = 2, mutable = true, typed = :partial, bitsonly = false), + (name = "Array{promote_type(X, Y)}", f = ((x, y) -> [x, y]), N = 2, mutable = true, typed = :promoted, bitsonly = false), + (name = "mutable struct{Any, Y}", f = MutableDualWrapper{Any}, N = 2, mutable = true, typed = :partial, bitsonly = false), - (name = "Array{Any}", f = ((x, y) -> Any[x, y]), N = 2, mutable = true, typed = false, bitsonly = false), - (name = "mutable struct{Any,Any}", f = MutableDualWrapper{Any,Any}, N = 2, mutable = true, typed = false, bitsonly = false), + (name = "Array{Any}", f = ((x, y) -> Any[x, y]), N = 2, mutable = true, typed = false, bitsonly = false), + (name = "mutable struct{Any, Any}", f = MutableDualWrapper{Any,Any}, N = 2, mutable = true, typed = false, bitsonly = false), # StaticArrays extension - (name = "SVector{1,X}", f = (SVector{1} ∘ tuple), N = 1, mutable = false, typed = true, bitsonly = false), - (name = "SVector{1,Any}", f = (SVector{1,Any} ∘ tuple), N = 1, mutable = false, typed = false, bitsonly = false), - (name = "MVector{1,X}", f = (MVector{1} ∘ tuple), N = 1, mutable = true, typed = true, bitsonly = false), - (name = "MVector{1,Any}", f = (MVector{1,Any} ∘ tuple), N = 1, mutable = true, typed = false, bitsonly = false), - (name = "SVector{2,promote_type(X,Y)}", f = (SVector{2} ∘ tuple), N = 2, mutable = false, typed = :promoted, bitsonly = false), - (name = "SVector{2,Any}", f = (SVector{2,Any} ∘ tuple), N = 2, mutable = false, typed = false, bitsonly = false), - (name = "MVector{2,promote_type(X,Y)}", f = (MVector{2} ∘ tuple), N = 2, mutable = true, typed = :promoted, bitsonly = false), - (name = "MVector{2,Any}", f = (MVector{2,Any} ∘ tuple), N = 2, mutable = true, typed = false, bitsonly = false), + (name = "SVector{1, X}", f = (SVector{1} ∘ tuple), N = 1, mutable = false, typed = true, bitsonly = false), + (name = "SVector{1, Any}", f = (SVector{1, Any} ∘ tuple), N = 1, mutable = false, typed = false, bitsonly = false), + (name = "MVector{1, X}", f = (MVector{1} ∘ tuple), N = 1, mutable = true, typed = true, bitsonly = false), + (name = "MVector{1, Any}", f = (MVector{1, Any} ∘ tuple), N = 1, mutable = true, typed = false, bitsonly = false), + (name = "SVector{2, promote_type(X, Y)}", f = (SVector{2} ∘ tuple), N = 2, mutable = false, typed = :promoted, bitsonly = false), + (name = "SVector{2, Any}", f = (SVector{2, Any} ∘ tuple), N = 2, mutable = false, typed = false, bitsonly = false), + (name = "MVector{2, promote_type(X, Y)}", f = (MVector{2} ∘ tuple), N = 2, mutable = true, typed = :promoted, bitsonly = false), + (name = "MVector{2, Any}", f = (MVector{2, Any} ∘ tuple), N = 2, mutable = true, typed = false, bitsonly = false), # GPUArrays extension - (name = "JLArray{X}", f = (x -> JLArray([x])), N = 1, mutable = true, typed = true, bitsonly = true), - (name = "JLArray{promote_type(X,Y)}", f = ((x, y) -> JLArray([x, y])), N = 2, mutable = true, typed = :promoted, bitsonly = true), + (name = "JLArray{X}", f = (x -> JLArray([x])), N = 1, mutable = true, typed = true, bitsonly = true), + (name = "JLArray{promote_type(X, Y)}", f = ((x, y) -> JLArray([x, y])), N = 2, mutable = true, typed = :promoted, bitsonly = true), ] #! format: on @@ -248,10 +248,10 @@ _memory(x::Vector) = Memory{eltype(x)}(x) #! format: off push!( wrappers, - (name = "Memory{X}", f = (x -> _memory([x])), N = 1, mutable = true, typed = true, bitsonly = false), - (name = "Memory{Any}", f = (x -> _memory(Any[x])), N = 1, mutable = true, typed = false, bitsonly = false), - (name = "Memory{promote_type(X,Y)}", f = ((x, y) -> _memory([x, y])), N = 2, mutable = true, typed = :promoted, bitsonly = false), - (name = "Memory{Any}", f = ((x, y) -> _memory(Any[x, y])), N = 2, mutable = true, typed = false, bitsonly = false), + (name = "Memory{X}", f = (x -> _memory([x])), N = 1, mutable = true, typed = true, bitsonly = false), + (name = "Memory{Any}", f = (x -> _memory(Any[x])), N = 1, mutable = true, typed = false, bitsonly = false), + (name = "Memory{promote_type(X, Y)}", f = ((x, y) -> _memory([x, y])), N = 2, mutable = true, typed = :promoted, bitsonly = false), + (name = "Memory{Any}", f = ((x, y) -> _memory(Any[x, y])), N = 2, mutable = true, typed = false, bitsonly = false), ) #! format: on end From 42e128e42b1ab5db6d21083f72f3526c004a98be Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Tue, 21 Jan 2025 20:09:58 -0800 Subject: [PATCH 14/17] One more tweak to format on/off tags --- test/recursive_maps.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/recursive_maps.jl b/test/recursive_maps.jl index 1974076160..6713bb9870 100644 --- a/test/recursive_maps.jl +++ b/test/recursive_maps.jl @@ -240,12 +240,10 @@ const wrappers = [ (name = "JLArray{X}", f = (x -> JLArray([x])), N = 1, mutable = true, typed = true, bitsonly = true), (name = "JLArray{promote_type(X, Y)}", f = ((x, y) -> JLArray([x, y])), N = 2, mutable = true, typed = :promoted, bitsonly = true), ] -#! format: on @static if VERSION < v"1.11-" else _memory(x::Vector) = Memory{eltype(x)}(x) -#! format: off push!( wrappers, (name = "Memory{X}", f = (x -> _memory([x])), N = 1, mutable = true, typed = true, bitsonly = false), @@ -253,8 +251,8 @@ push!( (name = "Memory{promote_type(X, Y)}", f = ((x, y) -> _memory([x, y])), N = 2, mutable = true, typed = :promoted, bitsonly = false), (name = "Memory{Any}", f = ((x, y) -> _memory(Any[x, y])), N = 2, mutable = true, typed = false, bitsonly = false), ) -#! format: on end +#! format: on function test_make_zero() @testset "scalars" begin From 4cc1ed06c1e7be8bc5ae9b958da81927b55cc8f4 Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Wed, 22 Jan 2025 13:18:46 -0800 Subject: [PATCH 15/17] Fix formatting --- src/typeutils/recursive_add.jl | 4 +- src/typeutils/recursive_maps.jl | 175 +++++++++++++++++--------------- test/recursive_maps.jl | 44 ++++---- 3 files changed, 119 insertions(+), 104 deletions(-) diff --git a/src/typeutils/recursive_add.jl b/src/typeutils/recursive_add.jl index 368eca3d09..e8fe633b38 100644 --- a/src/typeutils/recursive_add.jl +++ b/src/typeutils/recursive_add.jl @@ -19,7 +19,9 @@ types, such that `zi = xi + f(yi)` applies to differentiable values, while `zi = to non-differentiable values. If a custom callable is passed, it is combined with the default, as `recursive_add` is not generally capable of traversing inactive objects. """ -function recursive_add(x::T, y::T, f::F = identity, forcelhs::L = guaranteed_const) where {T, F, L} +function recursive_add( + x::T, y::T, f::F = identity, forcelhs::L = guaranteed_const + ) where {T, F, L} function addf(xi::S, yi::S) where {S} @assert EnzymeCore.isvectortype(S) return ((xi + f(yi))::S,) diff --git a/src/typeutils/recursive_maps.jl b/src/typeutils/recursive_maps.jl index d97090869c..b1745157c9 100644 --- a/src/typeutils/recursive_maps.jl +++ b/src/typeutils/recursive_maps.jl @@ -47,37 +47,37 @@ recommended for non-interactive usage and is the default. The updating constructor `InactiveConfig(config::InactiveConfig, extra)` returns a new config that extends `config` with an additional `extra` function. """ -struct InactiveConfig{copy_if_inactive,runtime_inactive,E} +struct InactiveConfig{copy_if_inactive, runtime_inactive, E} extra::E - function InactiveConfig{C,R}(extra::E) where {C,R,E} + function InactiveConfig{C, R}(extra::E) where {C, R, E} @assert Base.issingletontype(E) - return new{C::Bool,R::Bool,E}(extra) + return new{C::Bool, R::Bool, E}(extra) end end function InactiveConfig( - extra::E=(_ -> (@nospecialize; false)); - copy_if_inactive::Val{C}=Val(false), runtime_inactive::Val{R}=Val(false), -) where {E,C,R} - return InactiveConfig{C,R}(extra) + extra::E = (_ -> (@nospecialize; false)); + copy_if_inactive::Val{C} = Val(false), runtime_inactive::Val{R} = Val(false), + ) where {E, C, R} + return InactiveConfig{C, R}(extra) end -function InactiveConfig(config::InactiveConfig{C,R}, extra::E) where {C,R,E} +function InactiveConfig(config::InactiveConfig{C, R}, extra::E) where {C, R, E} @inline combinedextra(::Type{T}) where {T} = (config.extra(T) || extra(T)) - return InactiveConfig{C,R}(combinedextra) + return InactiveConfig{C, R}(combinedextra) end -function isinactivetype(::Type{T}, config::InactiveConfig{C,false}) where {T,C} +function isinactivetype(::Type{T}, config::InactiveConfig{C, false}) where {T, C} return guaranteed_const(T) || config.extra(T) # call guaranteed_const first, as this is a constant at runtime end -function isinactivetype(::Type{T}, config::InactiveConfig{C,true}) where {T,C} +function isinactivetype(::Type{T}, config::InactiveConfig{C, true}) where {T, C} return config.extra(T) || guaranteed_const_nongen(T, nothing) # call config.extra first, as guaranteed_const_nongen may incur runtime dispatch end -function isnonactivetype(::Type{T}, config::InactiveConfig{C,false}) where {T,C} +function isnonactivetype(::Type{T}, config::InactiveConfig{C, false}) where {T, C} return guaranteed_nonactive(T) || config.extra(T) # call guaranteed_const first, as this is a constant at runtime end -function isnonactivetype(::Type{T}, config::InactiveConfig{C,true}) where {T,C} +function isnonactivetype(::Type{T}, config::InactiveConfig{C, true}) where {T, C} return config.extra(T) || guaranteed_nonactive_nongen(T, nothing) # call config.extra first, as guaranteed_nonactive_nongen may incur runtime dispatch end @@ -198,14 +198,14 @@ array that the type notionally represents. function recursive_map end ## type alias for unified handling of out-of-place and partially-in-place recursive_map -const YS{Nout,T} = Union{Val{Nout},NTuple{Nout,T}} +const YS{Nout, T} = Union{Val{Nout}, NTuple{Nout, T}} @inline hasvalues(::Val{Nout}) where {Nout} = (Nout::Int; false) @inline hasvalues(::NTuple) = true ## main entry point: set default arguments, allocate IdDict if needed, exit early if possible function recursive_map( - f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, config::InactiveConfig=InactiveConfig() -) where {F,Nout,Nin,T} + f::F, ys::YS{Nout, T}, xs::NTuple{Nin, T}, config::InactiveConfig = InactiveConfig() + ) where {F, Nout, Nin, T} @assert (Nout == 1) || (Nout == 2) newys = if isinactivetype(T, config) recursive_map_inactive(nothing, ys, xs, config) @@ -214,17 +214,17 @@ function recursive_map( else recursive_map_inner(IdDict(), f, ys, xs, config) end - return newys::NTuple{Nout,T} + return newys::NTuple{Nout, T} end ## recursive methods function recursive_map( - seen::Union{Nothing,IdDict}, - f::F, - ys::YS{Nout,T}, - xs::NTuple{Nin,T}, - config::InactiveConfig=InactiveConfig(), -) where {F,Nout,Nin,T} + seen::Union{Nothing, IdDict}, + f::F, + ys::YS{Nout, T}, + xs::NTuple{Nin, T}, + config::InactiveConfig = InactiveConfig(), + ) where {F, Nout, Nin, T} # determine whether to continue recursion, copy/share, or retrieve from cache @assert (Nout == 1) || (Nout == 2) newys = if isinactivetype(T, config) @@ -236,12 +236,12 @@ function recursive_map( else recursive_map_inner(seen, f, ys, xs, config) end - return newys::NTuple{Nout,T} + return newys::NTuple{Nout, T} end @inline function recursive_map_inner( - seen, f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, config -) where {F,Nout,Nin,T} + seen, f::F, ys::YS{Nout, T}, xs::NTuple{Nin, T}, config + ) where {F, Nout, Nin, T} # forward to appropriate handler for leaf vs. mutable vs. immutable type @assert !isabstracttype(T) @assert isconcretetype(T) @@ -252,12 +252,12 @@ end else recursive_map_immutable(seen, f, ys, xs, config) end - return newys::NTuple{Nout,T} + return newys::NTuple{Nout, T} end @inline function recursive_map_mutable( - seen, f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, config -) where {F,Nout,Nin,T} + seen, f::F, ys::YS{Nout, T}, xs::NTuple{Nin, T}, config + ) where {F, Nout, Nin, T} @assert ismutabletype(T) if !hasvalues(ys) && !(T <: DenseArray) && all(isbitstype, fieldtypes(T)) # fast path for out-of-place handling when all fields are bitstypes, which rules @@ -274,12 +274,12 @@ end maybecache!(seen, newys, xs) recursive_map_mutable_inner!(seen, f, newys, ys, xs, config) end - return newys::NTuple{Nout,T} + return newys::NTuple{Nout, T} end @inline function recursive_map_mutable_inner!( - seen, f::F, newys::NTuple{Nout,T}, ys::YS{Nout,T}, xs::NTuple{Nin,T}, config -) where {F,Nout,Nin,T<:DenseArray} + seen, f::F, newys::NTuple{Nout, T}, ys::YS{Nout, T}, xs::NTuple{Nin, T}, config + ) where {F, Nout, Nin, T <: DenseArray} if (Nout == 1) && isbitstype(eltype(T)) newy = only(newys) if hasvalues(ys) @@ -329,12 +329,12 @@ end # maybecache! _should_ be a no-op here; call it anyway for consistency maybecache!(seen, newys, xs) end - return newys::NTuple{Nout,T} + return newys::NTuple{Nout, T} end @generated function recursive_map_immutable_inner( - seen, f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, config -) where {F,Nout,Nin,T} + seen, f::F, ys::YS{Nout, T}, xs::NTuple{Nin, T}, config + ) where {F, Nout, Nin, T} nf = fieldcount(T) return quote @inline @@ -367,13 +367,13 @@ end newys = Base.@ntuple $Nout j -> begin $(Expr(:splatnew, :T, :(Base.@ntuple $nf i -> newys_i[j]))) end - return newys::NTuple{Nout,T} + return newys::NTuple{Nout, T} end end Base.@propagate_inbounds function recursive_map_item!( - i, seen, f::F, newys::NTuple{Nout,T}, ys::YS{Nout,T}, xs::NTuple{Nin,T}, config -) where {F,Nout,Nin,T} + i, seen, f::F, newys::NTuple{Nout, T}, ys::YS{Nout, T}, xs::NTuple{Nin, T}, config + ) where {F, Nout, Nin, T} if isinitialized(first(xs), i) check_allinitialized(Base.tail(xs), i) newys_i = recursive_map_item(i, seen, f, ys, xs, config) @@ -401,50 +401,50 @@ end # function barriers such that abstractly typed items trigger minimal runtime dispatch function recursive_map_barrier( - seen, f::F, ::Val{Nout}, config::InactiveConfig, xs_i::Vararg{ST,Nin} -) where {F,Nout,Nin,ST} - return recursive_map(seen, f, Val(Nout), xs_i, config)::NTuple{Nout,ST} + seen, f::F, ::Val{Nout}, config::InactiveConfig, xs_i::Vararg{ST, Nin} + ) where {F, Nout, Nin, ST} + return recursive_map(seen, f, Val(Nout), xs_i, config)::NTuple{Nout, ST} end function recursive_map_barrier!!( - seen, f::F, y_i::ST, config::InactiveConfig, xs_i::Vararg{ST,Nin} -) where {F,Nin,ST} - return recursive_map(seen, f, (y_i,), xs_i, config)::NTuple{1,ST} + seen, f::F, y_i::ST, config::InactiveConfig, xs_i::Vararg{ST, Nin} + ) where {F, Nin, ST} + return recursive_map(seen, f, (y_i,), xs_i, config)::NTuple{1, ST} end function recursive_map_barrier!!( # may not show in coverage but is covered via accumulate_into! TODO: ensure coverage via VectorSpace once implemented - seen, f::F, y1_i::ST, y2_i::ST, config::InactiveConfig, xs_i::Vararg{ST,Nin} -) where {F,Nin,ST} + seen, f::F, y1_i::ST, y2_i::ST, config::InactiveConfig, xs_i::Vararg{ST, Nin} + ) where {F, Nin, ST} ys_i = (y1_i, y2_i) - return recursive_map(seen, f, ys_i, xs_i, config)::NTuple{2,ST} + return recursive_map(seen, f, ys_i, xs_i, config)::NTuple{2, ST} end ## recursion base case handlers @inline function recursive_map_leaf( - seen, f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T} -) where {F,Nout,Nin,T} + seen, f::F, ys::YS{Nout, T}, xs::NTuple{Nin, T} + ) where {F, Nout, Nin, T} # apply the mapped function to leaf values if !hasvalues(ys) || isbitstype(T) || isscalartype(T) - newys = f(xs...)::NTuple{Nout,T} + newys = f(xs...)::NTuple{Nout, T} else # !isbitstype(T) - newys = f(ys..., xs...)::NTuple{Nout,T} + newys = f(ys..., xs...)::NTuple{Nout, T} if ismutabletype(T) @assert newys === ys end end maybecache!(seen, newys, xs) - return newys::NTuple{Nout,T} + return newys::NTuple{Nout, T} end @inline function recursive_map_inactive( - _, ys::NTuple{Nout,T}, xs::NTuple{Nin,T}, ::InactiveConfig{copy_if_inactive} -) where {Nout,Nin,T,copy_if_inactive} - return ys::NTuple{Nout,T} + _, ys::NTuple{Nout, T}, xs::NTuple{Nin, T}, ::InactiveConfig{copy_if_inactive} + ) where {Nout, Nin, T, copy_if_inactive} + return ys::NTuple{Nout, T} end @inline function recursive_map_inactive( - seen, ::Val{Nout}, (x1,)::NTuple{Nin,T}, ::InactiveConfig{copy_if_inactive} -) where {Nout,Nin,T,copy_if_inactive} + seen, ::Val{Nout}, (x1,)::NTuple{Nin, T}, ::InactiveConfig{copy_if_inactive} + ) where {Nout, Nin, T, copy_if_inactive} @inline y = if copy_if_inactive && !isbitstype(T) if isnothing(seen) @@ -483,7 +483,10 @@ indeed identically the same tuple `ys`. See [`recursive_map`](@ref) for details. function recursive_map! end function recursive_map!( - f!!::F, ys::NTuple{Nout, T}, xs::NTuple{Nin, T}, config::InactiveConfig = InactiveConfig() + f!!::F, + ys::NTuple{Nout, T}, + xs::NTuple{Nin, T}, + config::InactiveConfig = InactiveConfig(), ) where {F, Nout, Nin, T} check_nonactive(T, config) newys = recursive_map(f!!, ys, xs, config) @@ -505,7 +508,9 @@ function recursive_map!( end ### recursive_map helpers -@generated function new_structvs(::Type{T}, fields::NTuple{N, Vector{Any}}, nfields_) where {T, N} +@generated function new_structvs( + ::Type{T}, fields::NTuple{N, Vector{Any}}, nfields_ + ) where {T, N} return quote @inline return Base.@ntuple $N j -> begin @@ -532,21 +537,25 @@ Base.@propagate_inbounds function setfield_force!(x::T, i, v) where {T} return nothing end -Base.@propagate_inbounds function getitems((x1, xtail...)::Tuple{T,T,Vararg{T,N}}, i) where {T,N} +Base.@propagate_inbounds function getitems( + (x1, xtail...)::Tuple{T, T, Vararg{T, N}}, i + ) where {T, N} return (getitem(x1, i), getitems(xtail, i)...) end Base.@propagate_inbounds getitems((x1,)::Tuple{T}, i) where {T} = (getitem(x1, i),) Base.@propagate_inbounds function setitems!( # may not show in coverage but is covered via accumulate_into! TODO: ensure coverage via VectorSpace once implemented - (x1, xtail...)::Tuple{T,T,Vararg{T,N}}, i, (v1, vtail...)::Tuple{ST,ST,Vararg{ST,N}} -) where {T,ST,N} + (x1, xtail...)::Tuple{T, T, Vararg{T, N}}, + i, + (v1, vtail...)::Tuple{ST, ST, Vararg{ST, N}}, + ) where {T, ST, N} setitem!(x1, i, v1) setitems!(xtail, i, vtail) return nothing end -Base.@propagate_inbounds function setitems!((x1,)::Tuple{T}, i, (v1,)::Tuple{ST}) where {T,ST} +Base.@propagate_inbounds function setitems!((x1,)::Tuple{T}, i, (v1,)::Tuple{ST}) where {T, ST} setitem!(x1, i, v1) return nothing end @@ -560,7 +569,9 @@ end @inline shouldcache(::IdDict, ::Type{T}) where {T} = iscachedtype(T) @inline shouldcache(::Nothing, ::Type{T}) where {T} = false -@inline function maybecache!(seen, newys::NTuple{Nout, T}, (x1, xtail...)::NTuple{Nin, T}) where {Nout, Nin, T} +@inline function maybecache!( + seen, newys::NTuple{Nout, T}, (x1, xtail...)::NTuple{Nin, T} + ) where {Nout, Nin, T} if shouldcache(seen, T) seen[x1] = if (Nout == 1) && (Nin == 1) only(newys) @@ -571,24 +582,26 @@ end return nothing end -@inline function hascache(seen, (x1,)::NTuple{Nin,T}) where {Nin,T} +@inline function hascache(seen, (x1,)::NTuple{Nin, T}) where {Nin, T} return shouldcache(seen, T) ? haskey(seen, x1) : false end -@inline function getcached(seen::IdDict, ::Val{Nout}, (x1, xtail...)::NTuple{Nin,T}) where {Nout,Nin,T} +@inline function getcached( + seen::IdDict, ::Val{Nout}, (x1, xtail...)::NTuple{Nin, T} + ) where {Nout, Nin, T} newys = if (Nout == 1) && (Nin == 1) (seen[x1]::T,) else # may not show in coverage but is covered via accumulate_into! TODO: ensure coverage via VectorSpace once implemented - cache = seen[x1]::NTuple{(Nout + Nin - 1),T} - cachedtail = cache[(Nout+1):end] + cache = seen[x1]::NTuple{(Nout + Nin - 1), T} + cachedtail = cache[(Nout + 1):end] check_identical(cachedtail, xtail) # check compatible structure cache[1:Nout] end - return newys::NTuple{Nout,T} + return newys::NTuple{Nout, T} end ## argument validation -Base.@propagate_inbounds function check_initialized(x, i, initialized=true) +Base.@propagate_inbounds function check_initialized(x, i, initialized = true) if isinitialized(x, i) != initialized throw_initialized() # TODO: hit this when VectorSpace implemented end @@ -596,21 +609,21 @@ Base.@propagate_inbounds function check_initialized(x, i, initialized=true) end Base.@propagate_inbounds function check_allinitialized( # TODO: hit this when VectorSpace implemented - (x1, xtail...)::Tuple{T,T,Vararg{T,N}}, i, initialized=true -) where {T,N} + (x1, xtail...)::Tuple{T, T, Vararg{T, N}}, i, initialized = true + ) where {T, N} check_initialized(x1, i, initialized) check_allinitialized(xtail, i, initialized) return nothing end Base.@propagate_inbounds function check_allinitialized( - (x1,)::Tuple{T}, i, initialized=true -) where {T} + (x1,)::Tuple{T}, i, initialized = true + ) where {T} check_initialized(x1, i, initialized) return nothing end -Base.@propagate_inbounds check_allinitialized(::Tuple{}, i, initialized=true) = nothing +Base.@propagate_inbounds check_allinitialized(::Tuple{}, i, initialized = true) = nothing @inline function check_identical(u, v) # TODO: hit this when VectorSpace implemented if u !== v @@ -670,21 +683,21 @@ end return nothing end -@inline function _make_zero_inner!(val, args::Vararg{Any,M}; kws...) where {M} +@inline function _make_zero_inner!(val, args::Vararg{Any, M}; kws...) where {M} return recursive_map!(_make_zero!!, (val,), (val,), make_zero!_config(args...; kws...)) end -@inline function _make_zero_inner!(val, seen::IdDict, args::Vararg{Any,M}; kws...) where {M} +@inline function _make_zero_inner!(val, seen::IdDict, args::Vararg{Any, M}; kws...) where {M} config = make_zero!_config(args...; kws...) return recursive_map!(seen, _make_zero!!, (val,), (val,), config) end # map make_zero(!) args/kws to config -@inline make_zero_config(C) = InactiveConfig(; copy_if_inactive=C) -@inline make_zero_config(C, R) = InactiveConfig(; copy_if_inactive=C, runtime_inactive=R) +@inline make_zero_config(C) = InactiveConfig(; copy_if_inactive = C) +@inline make_zero_config(C, R) = InactiveConfig(; copy_if_inactive = C, runtime_inactive = R) @inline make_zero_config(; kws...) = InactiveConfig(; kws...) -@inline make_zero!_config(R) = InactiveConfig(; runtime_inactive=R) -@inline function make_zero!_config(; runtime_inactive=nothing) +@inline make_zero!_config(R) = InactiveConfig(; runtime_inactive = R) +@inline function make_zero!_config(; runtime_inactive = nothing) if isnothing(runtime_inactive) return InactiveConfig() else diff --git a/test/recursive_maps.jl b/test/recursive_maps.jl index 6713bb9870..a45bd71fc3 100644 --- a/test/recursive_maps.jl +++ b/test/recursive_maps.jl @@ -60,12 +60,12 @@ Base.:(==)(a::MutableWrapper, b::MutableWrapper) = (a === b) || (a.x == b.x) getx(a::MutableWrapper) = a.x setx!(a::MutableWrapper, x) = (a.x = x) -struct DualWrapper{Tx,Ty} +struct DualWrapper{Tx, Ty} x::Tx y::Ty end -DualWrapper{T}(x::T, y) where {T} = DualWrapper{T,typeof(y)}(x, y) +DualWrapper{T}(x::T, y) where {T} = DualWrapper{T, typeof(y)}(x, y) function Base.:(==)(a::DualWrapper, b::DualWrapper) return (a === b) || ((a.x == b.x) && (a.y == b.y)) @@ -74,12 +74,12 @@ end getx(a::DualWrapper) = a.x gety(a::DualWrapper) = a.y -mutable struct MutableDualWrapper{Tx,Ty} +mutable struct MutableDualWrapper{Tx, Ty} x::Tx y::Ty end -MutableDualWrapper{T}(x::T, y) where {T} = MutableDualWrapper{T,typeof(y)}(x, y) +MutableDualWrapper{T}(x::T, y) where {T} = MutableDualWrapper{T, typeof(y)}(x, y) function Base.:(==)(a::MutableDualWrapper, b::MutableDualWrapper) return (a === b) || ((a.x == b.x) && (a.y == b.y)) @@ -91,14 +91,14 @@ gety(a::MutableDualWrapper) = a.y setx!(a::MutableDualWrapper, x) = (a.x = x) sety!(a::MutableDualWrapper, y) = (a.y = y) -struct Incomplete{T,U} +struct Incomplete{T, U} s::String x::Float64 w::T y::U # possibly not initialized z # not initialized - Incomplete(s, x, w) = new{typeof(w),Any}(s, x, w) - Incomplete(s, x, w, y) = new{typeof(w),typeof(y)}(s, x, w, y) + Incomplete(s, x, w) = new{typeof(w), Any}(s, x, w) + Incomplete(s, x, w, y) = new{typeof(w), typeof(y)}(s, x, w, y) end function Base.:(==)(a::Incomplete, b::Incomplete) @@ -564,11 +564,11 @@ function test_make_zero() @assert !EnzymeRules.inactive_type(typeof(a)) a_makez = make_zero(a, Val(false), Val(false)) @assert a_makez == MutableWrapper(0.0) - a_makez = make_zero(a; runtime_inactive=Val(false)) + a_makez = make_zero(a; runtime_inactive = Val(false)) @assert a_makez == MutableWrapper(0.0) a_makez = make_zero(a, Val(false), Val(true)) @assert a_makez == MutableWrapper(0.0) - a_makez = make_zero(a; runtime_inactive=Val(true)) + a_makez = make_zero(a; runtime_inactive = Val(true)) @assert a_makez == MutableWrapper(0.0) # mark MutableWrapper as inactive @@ -585,14 +585,14 @@ function test_make_zero() a_makez = @invokelatest make_zero(a, Val(false), Val(true)) @test a_makez === a a_makez = @invokelatest make_zero( - a; copy_if_inactive=Val(false), runtime_inactive=Val(true) + a; copy_if_inactive = Val(false), runtime_inactive = Val(true) ) @test a_makez === a a_makez = @invokelatest make_zero(a, Val(true), Val(true)) @test a_makez !== a @test a_makez == MutableWrapper(1.0) a_makez = @invokelatest make_zero( - a; copy_if_inactive=Val(true), runtime_inactive=Val(true) + a; copy_if_inactive = Val(true), runtime_inactive = Val(true) ) @test a_makez !== a @test a_makez == MutableWrapper(1.0) @@ -603,11 +603,11 @@ function test_make_zero() # verify that MutableWrapper is seen as active by both variants a_makez = @invokelatest make_zero(a, Val(false), Val(false)) @test a_makez == MutableWrapper(0.0) - a_makez = @invokelatest make_zero(a; runtime_inactive=Val(false)) + a_makez = @invokelatest make_zero(a; runtime_inactive = Val(false)) @test a_makez == MutableWrapper(0.0) a_makez = @invokelatest make_zero(a, Val(false), Val(true)) @test a_makez == MutableWrapper(0.0) - a_makez = @invokelatest make_zero(a; runtime_inactive=Val(true)) + a_makez = @invokelatest make_zero(a; runtime_inactive = Val(true)) @test a_makez == MutableWrapper(0.0) end @testset "undefined fields/unassigned elements" begin @@ -735,13 +735,13 @@ function test_make_zero!() end @testset "inactive" begin @testset "in $(wrapper.name)" for wrapper in filter( - w -> (w.mutable || (w.typed == true)), wrappers - ) + w -> (w.mutable || (w.typed == true)), wrappers + ) if wrapper.N == 1 for (inactive, condition) in [ - (inactivebits, true), - (inactivearr, !wrapper.bitsonly), - ] + (inactivebits, true), + (inactivearr, !wrapper.bitsonly), + ] condition || continue w = wrapper.f(inactive) make_zero!(w) @@ -880,13 +880,13 @@ function test_make_zero!() make_zero!(a, Val(false)) @assert a == MutableWrapper(0.0) a.x = 1.0 - make_zero!(a; runtime_inactive=Val(false)) + make_zero!(a; runtime_inactive = Val(false)) @assert a == MutableWrapper(0.0) a.x = 1.0 make_zero!(a, Val(true)) @assert a == MutableWrapper(0.0) a.x = 1.0 - make_zero!(a; runtime_inactive=Val(true)) + make_zero!(a; runtime_inactive = Val(true)) @assert a == MutableWrapper(0.0) # mark MutableWrapper as inactive @@ -917,13 +917,13 @@ function test_make_zero!() @invokelatest make_zero!(a, Val(true)) @test a == MutableWrapper(0.0) a.x = 1.0 - @invokelatest make_zero!(a; runtime_inactive=Val(true)) + @invokelatest make_zero!(a; runtime_inactive = Val(true)) @test a == MutableWrapper(0.0) a.x = 1.0 @invokelatest make_zero!(a, Val(false)) @test a == MutableWrapper(0.0) a.x = 1.0 - @invokelatest make_zero!(a; runtime_inactive=Val(false)) + @invokelatest make_zero!(a; runtime_inactive = Val(false)) @test a == MutableWrapper(0.0) end @testset "undefined fields/unassigned elements" begin From 94fdb995d4b3d487803f330d2b7f0ba6ea0874f1 Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Wed, 22 Jan 2025 13:32:36 -0800 Subject: [PATCH 16/17] Final formatting fix --- test/runtests.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 4b55ab5882..f5a961ad72 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -600,10 +600,10 @@ end return 0.0+1.0im end @testset "setinact2 Const" reverse_holomorphic_array_tests( - setinact2, 3.4 + 2.7im, 0.0; val_expected=2(3.4 + 2.7im), ret=Const, mapf=false + setinact2, 3.4 + 2.7im, 0.0; val_expected = 2(3.4 + 2.7im), ret = Const, mapf = false ) @testset "setinact2 Active" reverse_holomorphic_array_tests( - setinact2, 3.4 + 2.7im, 0.0; val_expected=2(3.4 + 2.7im), ret=Active, mapf=false + setinact2, 3.4 + 2.7im, 0.0; val_expected = 2(3.4 + 2.7im), ret = Active, mapf = false ) function setact(z) @@ -611,10 +611,10 @@ end return z[1][1] # returns scalar for both [x] and [(x,)] end @testset "setact Const" reverse_holomorphic_array_tests( - setact, 3.4 + 2.7im, 0.0; val_expected=2(3.4 + 2.7im), ret=Const, mapf=false + setact, 3.4 + 2.7im, 0.0; val_expected = 2(3.4 + 2.7im), ret = Const, mapf = false ) @testset "setact Active" reverse_holomorphic_array_tests( - setact, 3.4 + 2.7im, 2.0; val_expected=2(3.4 + 2.7im), ret=Active, mapf=false + setact, 3.4 + 2.7im, 2.0; val_expected = 2(3.4 + 2.7im), ret = Active, mapf = false ) function upgrade(z) From 3de96f05f50a427f5148baa44892ab69ca70af8a Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Wed, 22 Jan 2025 13:48:16 -0800 Subject: [PATCH 17/17] Make Documenter happy Give new docstrings a home in the manual --- docs/src/internal_api.md | 2 +- src/typeutils/recursive_add.jl | 12 ------------ src/typeutils/recursive_maps.jl | 12 ------------ 3 files changed, 1 insertion(+), 25 deletions(-) diff --git a/docs/src/internal_api.md b/docs/src/internal_api.md index bc26e86a44..61cff49ee0 100644 --- a/docs/src/internal_api.md +++ b/docs/src/internal_api.md @@ -7,6 +7,6 @@ without deprecation. ```@autodocs -Modules = [Enzyme.Compiler] +Modules = [Enzyme.Compiler, Enzyme.Compiler.RecursiveMaps] Order = [:module, :type, :constant, :macro, :function] ``` diff --git a/src/typeutils/recursive_add.jl b/src/typeutils/recursive_add.jl index e8fe633b38..45894c9c7c 100644 --- a/src/typeutils/recursive_add.jl +++ b/src/typeutils/recursive_add.jl @@ -3,10 +3,6 @@ using .RecursiveMaps: RecursiveMaps, recursive_map, recursive_map! """ recursive_add(x::T, y::T, f=identity, forcelhs=guaranteed_const) -!!! warning - Internal function, documented for developer convenience but not covered by semver API - stability guarantees - Recursively construct `z::T` such that `zi = xi + f(yi)` where `zi`, `xi`, and `yi` are corresponding values from `z`, `x`, and `y`. In other words, this is a recursive generalization of `x .+ f.(y)`. @@ -37,10 +33,6 @@ end f, seen::IdDict, config::RecursiveMaps.InactiveConfig=RecursiveMaps.InactiveConfig() ) -!!! warning - Internal function, documented for developer convenience but not covered by semver API - stability guarantees - Recursively accumulate from values into keys, generalizing `key .+= f.(value)` to arbitrary types. This accumulation is applied to each key-value pair in `seen::IdDict` where each key is of a mutable or non-isbits vector type and the corresponding value is of the same type @@ -93,10 +85,6 @@ end """ accumulate_into!(into::T, from::T) -!!! warning - Internal function, documented for developer convenience but not covered by semver API - stability guarantees - Recursively accumulate from `from` into `into` and zero `from`, such that `into_i += from_i` and `from_i = 0`, where `into_i` and `from_i` are corresponding values within `into` and `from`. In other words, this is a recursive generalization of diff --git a/src/typeutils/recursive_maps.jl b/src/typeutils/recursive_maps.jl index b1745157c9..d54cdc9923 100644 --- a/src/typeutils/recursive_maps.jl +++ b/src/typeutils/recursive_maps.jl @@ -12,10 +12,6 @@ using ..Compiler: guaranteed_const, guaranteed_const_nongen, guaranteed_nonactiv config = InactiveConfig{copy_if_inactive::Bool,runtime_inactive::Bool}(extra) newconfig = InactiveConfig(config::InactiveConfig, extra) -!!! warning - Internal type, documented for developer convenience but not covered by semver API - stability guarantees - Config type for specifying which parts of objects should be skipped by `recursive_map{!}`. At a minimum, parts that Enzyme always considers inactive are skipped. An inactive type is a @@ -110,10 +106,6 @@ end config::InactiveConfig=InactiveConfig(), )::T -!!! warning - Internal function, documented for developer convenience but not covered by semver API - stability guarantees - Recurse through `Nin` objects `xs = (x1::T, x2::T, ..., xNin::T)` of the same type, mapping the function `f` over every differentiable value encountered and building `Nout` new objects `(y1::T, ...)` from the resulting values `(y1_i, ...) = f(x1_i, ..., xNin_i)`. Only @@ -468,10 +460,6 @@ end isinactivetype::InactiveConfig=InactiveConfig(), )::Nothing -!!! warning - Internal function, documented for developer convenience but not covered by semver API - stability guarantees - Recurse through `Nin` objects `xs = (x1::T, x2::T, ..., xNin::T)` of the same type, mapping the function `f!!` over every differentiable value encountered and updating `(y1::T, ...)` in-place with the resulting values.