Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add recursive map generalizing the make_zero mechanism #1852

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/src/internal_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@
without deprecation.

```@autodocs
Modules = [Enzyme.Compiler]
Modules = [Enzyme.Compiler, Enzyme.Compiler.RecursiveMaps]
Order = [:module, :type, :constant, :macro, :function]
```
49 changes: 5 additions & 44 deletions ext/EnzymeStaticArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,50 +32,11 @@ end
end
end

@inline function Enzyme.EnzymeCore.make_zero(
prev::FT
) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:SArray{S,T}}
return Base.zero(prev)::FT
end
@inline function Enzyme.EnzymeCore.make_zero(
prev::FT
) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}}
return Base.zero(prev)::FT
end

@inline function Enzyme.EnzymeCore.make_zero(
::Type{FT}, seen::IdDict, prev::FT, ::Val{copy_if_inactive} = Val(false)
) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:SArray{S,T},copy_if_inactive}
return Base.zero(prev)::FT
end
@inline function Enzyme.EnzymeCore.make_zero(
::Type{FT}, seen::IdDict, prev::FT, ::Val{copy_if_inactive} = Val(false)
) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T},copy_if_inactive}
if haskey(seen, prev)
return seen[prev]
end
new = Base.zero(prev)::FT
seen[prev] = new
return new
end

@inline function Enzyme.EnzymeCore.make_zero!(
prev::FT, seen
) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}}
if !isnothing(seen)
if prev in seen
return nothing
end
push!(seen, prev)
end
fill!(prev, zero(T))
return nothing
end
@inline function Enzyme.EnzymeCore.make_zero!(
prev::FT
) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}}
Enzyme.EnzymeCore.make_zero!(prev, nothing)
return nothing
# SArrays and MArrays don't need special treatment for `make_zero(!)` to work or be correct,
# but in case their dedicated `zero` and `fill!` methods are more efficient than
# `make_zero(!)`s recursion, we opt into treating them as leaves.
@inline function Enzyme.EnzymeCore.isvectortype(::Type{<:StaticArray{S, T}}) where {S, T}
return isbitstype(T) && Enzyme.EnzymeCore.isscalartype(T)
end

end
135 changes: 124 additions & 11 deletions lib/EnzymeCore/src/EnzymeCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -506,28 +506,141 @@ function autodiff_thunk end
function autodiff_deferred_thunk end

"""
make_zero(::Type{T}, seen::IdDict, prev::T, ::Val{copy_if_inactive}=Val(false))::T

Recursively make a zero'd copy of the value `prev` of type `T`. The argument `copy_if_inactive` specifies
what to do if the type `T` is guaranteed to be inactive, use the primal (the default) or still copy the value.
make_zero(prev::T; copy_if_inactive=Val(false), runtime_inactive=Val(false))::T
make_zero(prev::T, ::Val{copy_if_inactive}[, ::Val{runtime_inactive}])::T
make_zero(
::Type{T}, seen::IdDict, prev::T;
copy_if_inactive=Val(false), runtime_inactive=Val(false),
)::T
make_zero(
::Type{T}, seen::IdDict, prev::T, ::Val{copy_if_inactive}[, ::Val{runtime_inactive}]
)::T

Recursively make a copy of the value `prev::T` in which all differentiable values are zeroed.

The argument `copy_if_inactive` specifies what to do if the type `T` or any
of its constituent parts is guaranteed to be inactive (non-differentiable): reuse `prev`s
instance (if `Val(false)`, the default) or make a copy (if `Val(true)`).

The argument `runtime_inactive` specifies whether this function should respect runtime
semantics when determining if a type is guaranteed inactive. If `Val(true)`, changes and
additions to the methods of `EnzymeRules.inactive_type` will be reflected in the behavior of
this function. If `Val(false)`, the inactivity of a type is determined once per Julia
session and reused in every subsequent call with `runtime_inactive=Val(false)`, regardless
of changes to `EnzymeRules.inactive_type` (in technical terms, there will be no invalidation
when changing `EnzymeRules.inactive_type`). Using `runtime_inactive = Val(false)` may be
desireable in interactive sessions, but can sometimes impose a performance penalty and may
in rare cases break gradient compilation when used inside custom rules. Hence
`runtime_inactive = Val(true)` is preferred in non-interactive usage and is the default.

`copy_if_inactive` and `runtime_inactive` may be given as either positional or keywords
arguments, but not a combination.

Extending this method for custom types is rarely needed. If you implement a new type, such
as a GPU array type, on which `make_zero` should directly invoke `zero` when the eltype is
scalar, it is sufficient to implement `Base.zero` and make sure your type subtypes
`DenseArray`. (If subtyping `DenseArray` is not appropriate, extend
[`EnzymeCore.isvectortype`](@ref) instead.)
"""
function make_zero end

"""
make_zero!(val::T, seen::IdSet{Any}=IdSet())::Nothing
make_zero!(val::T, [seen::IdDict]; runtime_inactive=Val(true))::Nothing
make_zero!(val::T, [seen::IdDict], ::Val{runtime_inactive})::Nothing

Recursively set a variable's differentiable values to zero. Only applicable for types `T`
that are mutable or hold all differentiable values in mutable storage (e.g.,
`Tuple{Vector{Float64}}` qualifies but `Tuple{Float64}` does not). The recursion skips over
parts of `val` that are guaranteed to be inactive.

The argument `runtime_inactive` specifies whether this function should respect runtime
semantics when determining if a type is guaranteed inactive. If `Val(true)`, changes and
additions to the methods of `EnzymeRules.inactive_type` will be reflected in the behavior of
this function. If `Val(false)`, the inactivity of a type is determined once per Julia
session and reused in every subsequent call with `runtime_inactive=Val(false)`, regardless
of changes to `EnzymeRules.inactive_type` (in technical terms, there will be no invalidation
when changing `EnzymeRules.inactive_type`).

Using `runtime_inactive = Val(false)` may be preferred in interactive sessions, but can
sometimes impose a performance penalty and may in rare cases break gradient compilation when
used inside custom rules. Hence `runtime_inactive = Val(true)` is recommended for
non-interactive usage and is the default.

Recursively set a variables differentiable fields to zero. Only applicable for mutable types `T`.
`runtime_inactive` may be given as either a positional or a keyword argument.

Extending this method for custom types is rarely needed. If you implement a new mutable
type, such as a GPU array type, on which `make_zero!` should directly invoke
`fill!(x, false)` when the eltype is scalar, it is sufficient to implement `Base.zero`,
`Base.fill!`, and make sure your type subtypes `DenseArray`. (If subtyping `DenseArray` is
not appropriate, extend [`EnzymeCore.isvectortype`](@ref) instead.)
"""
function make_zero! end

"""
make_zero(prev::T)
isvectortype(::Type{T})::Bool
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a new API? It will need a version bump for EnzymeCore.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should discuss. The reason I put these helpers in EnzymeCore rather than keeping them internal was that the StaticArrays extension needed to add a method, so I figured there's a chance others might have to do the same for their custom types. However, subtyping DenseArray (and AbstractFloat if that ever becomes relevant) should almost always be sufficient. Either way, the point is only to make these extensible in package extensions. I don't think anyone should ever have to call them.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

However, the vector space wrapper functionality I've built on top of this (which will be a separate PR) will probably involve a new type in EnzymeCore, so if that gets accepted there will have to be a new EnzymeCore release anyway

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, absolutely. I think this is the right place to add them, and EnzymeCore is basically meant for people to be able to extend things without having to bite the load time bullet that is Enzyme.

Copy link
Contributor Author

@danielwe danielwe Jan 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Btw. names are infinitely bikesheddable, both in this case and elsewhere in the PR. My mindset working on this PR is to enable consistent treatment of arbitrary objects as vectors in a space spanned by the scalar (float) values reachable from the object, hence all the vector/scalar terminology, but I don't know if this works well or if it's confusing, especially as part of the public API.


Helper function to recursively make zero.
"""
@inline function make_zero(prev::T, ::Val{copy_if_inactive}=Val(false)) where {T, copy_if_inactive}
make_zero(Core.Typeof(prev), IdDict(), prev, Val(copy_if_inactive))
Trait defining types whose values should be considered leaf nodes when [`make_zero`](@ref)
and [`make_zero!`](@ref) recurse through an object.

By default, `isvectortype(T) == true` when `isscalartype(T) == true` or when
`T <: DenseArray{U}` where `U` is a bitstype and `isscalartype(U) == true`.

A new vector type, such as a GPU array type, should normally subtype `DenseArray` and
inherit `isvectortype` that way. However if this is not appropariate, `isvectortype` may be
extended directly as follows:

```julia
@inline function EnzymeCore.isvectortype(::Type{T}) where {T<:NewArray}
U = eltype(T)
return isbitstype(U) && EnzymeCore.isscalartype(U)
end
```

Such a type should implement `Base.zero` and, if mutable, `Base.fill!`.

Extending `isvectortype` is mostly relevant for the lowest-level of abstraction of memory at
which vector space operations like addition and scalar multiplication are supported, the
prototypical case being `Array`. Regular Julia structs with vector space-like semantics
should normally not extend `isvectorspace`; `make_zero(!)` will recurse into them and act
directly on their backing arrays, just like how Enzyme treats them when differentiating. For
example, structured matrix wrappers and sparse array types that are backed by `Array` should
not extend `isvectortype`.

See also [`isscalartype`](@ref).
"""
function isvectortype end

"""
isscalartype(::Type{T})::Bool

Trait defining a subset of [`isvectortype`](@ref) types that should not be considered
composite, such that even if the type is mutable, [`make_zero!`](@ref) will not try to zero
values of the type in-place. For example, `BigFloat` is a mutable type but does not support
in-place mutation through any Julia API; `isscalartype(BigFloat) == true` ensures that
`make_zero!` will not try to mutate `BigFloat` values.[^BigFloat]

By default, `isscalartype(T) == true` and `isscalartype(Complex{T}) == true` for concrete
types `T <: AbstractFloat`.

A hypothetical new real number type with Enzyme support should usually subtype
`AbstractFloat` and inherit the `isscalartype` trait that way. If this is not appropriate,
the function can be extended as follows:

```julia
@inline EnzymeCore.isscalartype(::Type{NewReal}) = true
@inline EnzymeCore.isscalartype(::Type{Complex{NewReal}}) = true
```

In either case, the type should implement `Base.zero`.

See also [`isvectortype`](@ref).

[^BigFloat]: Enzyme does not support differentiating `BigFloat` as of this writing; it is
mentioned here only to demonstrate that it would be inappropriate to use traits like
`ismutable` or `isbitstype` to choose between in-place and out-of-place zeroing,
showing the need for a dedicated `isscalartype` trait.
"""
function isscalartype end

function tape_type end

Expand Down
8 changes: 2 additions & 6 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -463,12 +463,8 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0))
# compute the correct complex derivative in reverse mode by propagating the conjugate return values
# then subtracting twice the imaginary component to get the correct result

for (k, v) in seen
Compiler.recursive_accumulate(k, v, refn_seed)
end
for (k, v) in seen2
Compiler.recursive_accumulate(k, v, imfn_seed)
end
Compiler.accumulate_seen!(refn_seed, seen)
Compiler.accumulate_seen!(imfn_seed, seen2)

fused = fuse_complex_results(results, args...)

Expand Down
5 changes: 5 additions & 0 deletions src/analyses/activity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,11 @@
return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState
end

Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_nonactive_nongen(::Type{T}, world)::Bool where {T}
danielwe marked this conversation as resolved.
Show resolved Hide resolved
rt = Enzyme.Compiler.active_reg_inner(T, (), world)
return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState

Check warning on line 432 in src/analyses/activity.jl

View check run for this annotation

Codecov / codecov/patch

src/analyses/activity.jl#L430-L432

Added lines #L430 - L432 were not covered by tests
end

"""
Enzyme.guess_activity(::Type{T}, mode::Enzyme.Mode)

Expand Down
2 changes: 1 addition & 1 deletion src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ const JuliaGlobalNameMap = Dict{String,Any}(
include("absint.jl")
include("llvm/transforms.jl")
include("llvm/passes.jl")
include("typeutils/make_zero.jl")
include("typeutils/recursive_maps.jl")

function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, @nospecialize(f), @nospecialize(tt::Type), world::UInt)
funcspec = my_methodinstance(mode == API.DEM_ForwardMode ? Forward : Reverse, typeof(f), tt, world)
Expand Down
50 changes: 1 addition & 49 deletions src/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,47 +253,6 @@
return EnzymeRules.AugmentedReturn(primal, shadow, shadow)
end


@inline function accumulate_into(
into::RT,
seen::IdDict,
from::RT,
)::Tuple{RT,RT} where {RT<:Array}
if Enzyme.Compiler.guaranteed_const(RT)
return (into, from)
end
if !haskey(seen, into)
seen[into] = (into, from)
for i in eachindex(from)
tup = accumulate_into(into[i], seen, from[i])
@inbounds into[i] = tup[1]
@inbounds from[i] = tup[2]
end
end
return seen[into]
end

@inline function accumulate_into(
into::RT,
seen::IdDict,
from::RT,
)::Tuple{RT,RT} where {RT<:AbstractFloat}
if !haskey(seen, into)
seen[into] = (into + from, RT(0))
end
return seen[into]
end

@inline function accumulate_into(into::RT, seen::IdDict, from::RT)::Tuple{RT,RT} where {RT}
if Enzyme.Compiler.guaranteed_const(RT)
return (into, from)
end
if !haskey(seen, into)
throw(AssertionError("Unknown type to accumulate into: $RT"))
end
return seen[into]
end

function EnzymeRules.reverse(
config::EnzymeRules.RevConfig,
func::Const{typeof(Base.deepcopy)},
Expand All @@ -302,15 +261,8 @@
x::Annotation{Ty},
) where {RT,Ty}
if EnzymeRules.needs_shadow(config)
if EnzymeRules.width(config) == 1
accumulate_into(x.dval, IdDict(), shadow)
else
for i = 1:EnzymeRules.width(config)
accumulate_into(x.dval[i], IdDict(), shadow[i])
end
end
Compiler.accumulate_into!(x.dval, shadow)

Check warning on line 264 in src/internal_rules.jl

View check run for this annotation

Codecov / codecov/patch

src/internal_rules.jl#L264

Added line #L264 was not covered by tests
end

return (nothing,)
end

Expand Down
Loading
Loading