Skip to content

Commit

Permalink
Merge branch 'main' into moreblastests
Browse files Browse the repository at this point in the history
  • Loading branch information
sethaxen authored Sep 26, 2023
2 parents 44082d8 + a957410 commit a5f78a0
Show file tree
Hide file tree
Showing 15 changed files with 450 additions and 212 deletions.
8 changes: 4 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Enzyme"
uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9"
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>"]
version = "0.11.7"
version = "0.11.8"

[deps]
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Expand All @@ -18,9 +18,9 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[compat]
CEnum = "0.4"
EnzymeCore = "0.5.2"
Enzyme_jll = "0.0.83"
GPUCompiler = "0.21, 0.22, 0.23"
EnzymeCore = "0.6.0"
Enzyme_jll = "0.0.86"
GPUCompiler = "0.21, 0.22, 0.23, 0.24"
LLVM = "6.1"
ObjectFile = "0.4"
Preferences = "1.4"
Expand Down
2 changes: 1 addition & 1 deletion lib/EnzymeCore/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "EnzymeCore"
uuid = "f151be2c-9106-41f4-ab19-57ee4f262869"
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>"]
version = "0.5.2"
version = "0.6.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
2 changes: 1 addition & 1 deletion lib/EnzymeCore/src/EnzymeCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ Enzyme will auto-differentiate in respect `Active` arguments.
struct Active{T} <: Annotation{T}
val::T
@inline Active(x::T1) where {T1} = new{T1}(x)
@inline Active(x::T1) where {T1 <: AbstractArray} = error("Unsupported Active{"*string(T1)*"}, consider Duplicated or Const")
@inline Active(x::T1) where {T1 <: Array} = error("Unsupported Active{"*string(T1)*"}, consider Duplicated or Const")
end
Adapt.adapt_structure(to, x::Active) = Active(adapt(to, x.val))

Expand Down
7 changes: 7 additions & 0 deletions lib/EnzymeCore/src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,4 +213,11 @@ function is_inactive_noinl_from_sig(@nospecialize(TT);
return isapplicable(inactive_noinl, TT; world, method_table, caller)
end

"""
inactive_type(::Type{Ty})
Mark a particular type `Ty` as always being inactive.
"""
inactive_type(::Type{Ty}) where Ty = false

end # EnzymeRules
4 changes: 2 additions & 2 deletions lib/EnzymeTestUtils/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "EnzymeTestUtils"
uuid = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
authors = ["Seth Axen <[email protected]>", "William Moses <[email protected]>", "Valentin Churavy <[email protected]>"]
version = "0.1.0"
version = "0.1.3"

[deps]
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
Expand All @@ -14,7 +14,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[compat]
ConstructionBase = "1.4.1"
Enzyme = "0.11"
EnzymeCore = "0.5"
EnzymeCore = "0.5, 0.6"
FiniteDifferences = "0.12.12"
MetaTesting = "0.1"
Quaternions = "0.7"
Expand Down
27 changes: 21 additions & 6 deletions lib/EnzymeTestUtils/src/finite_difference_calls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ end
_fd_forward(fdm, f, ::Type{<:Const}, y, activities) = ()

#=
_fd_reverse(fdm, f, ȳ, activities)
_fd_reverse(fdm, f, ȳ, activities, active_return)
Call `FiniteDifferences.j′vp` on `f` with the arguments `xs` determined by `activities`.
Expand All @@ -51,14 +51,14 @@ Call `FiniteDifferences.j′vp` on `f` with the arguments `xs` determined by `ac
- `f`: The function to differentiate.
- `ȳ`: The cotangent of the primal output `y=f(xs...)`.
- `activities`: activities that would be passed to `Enzyme.autodiff`
- `active_return`: whether the return is non-constant
# Returns
- `x̄s`: Derivatives of output `s` w.r.t. `xs` estimated by finite differencing.
=#
function _fd_reverse(fdm, f, ȳ, activities)
function _fd_reverse(fdm, f, ȳ, activities, active_return)
xs = map(x -> x.val, activities)
ignores = map(a -> a isa Const, activities)
f2 = _wrap_reverse_function(f, xs, ignores)
f2 = _wrap_reverse_function(active_return, f, xs, ignores)
all(ignores) && return map(zero_tangent, xs)
ignores = collect(ignores)
is_batch = _any_batch_duplicated(map(typeof, activities)...)
Expand Down Expand Up @@ -137,7 +137,7 @@ All arguments are copied before being passed to `f`, so that `fnew` is non-mutat
- `ignores`: Collection of `Bool`s, the same length as `xs`.
If `ignores[i] === true`, then `xs[i]` is ignored; `∂xs[i] === NoTangent()`.
=#
function _wrap_reverse_function(f, xs, ignores)
function _wrap_reverse_function(active_return, f, xs, ignores)
function fnew(sigargs...)
callargs = Any[]
retargs = Any[]
Expand All @@ -156,7 +156,22 @@ function _wrap_reverse_function(f, xs, ignores)
@assert j == length(sigargs) + 1
@assert length(callargs) == length(xs)
@assert length(retargs) == count(!, ignores)
return (deepcopy(f)(callargs...), retargs...)

# if an arg and a return alias, do not consider the contribution from the arg as returned here,
# it will already be taken into account. This is implemented using the deepcopy_internal, which
# will add all objects inside the return into the dict `zeros`.
zeros = IdDict()
origRet = Base.deepcopy_internal(deepcopy(f)(callargs...), zeros)

# we will now explicitly zero all objects returned, and replace any of the args with this
# zero, if the input and output alias.
if active_return
for k in keys(zeros)
zeros[k] = zero_tangent(k)
end
end

return (origRet, Base.deepcopy_internal(retargs, zeros)...)
end
return fnew
end
42 changes: 35 additions & 7 deletions lib/EnzymeTestUtils/src/test_reverse.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,22 @@
@inline call_with_kwargs(fkwargs::NT, f::FT, xs...) where {NT, FT} = f(xs...; fkwargs...)

# Force evaluation to avoid problem of a tuple being created but not being SROA'd
# Can cause some tests to unnecessarily fail without runtime activity
for N in 1:30
argexprs = [Symbol(:arg, Symbol(i)) for i in 1:N]
eval(quote
function call_with_kwargs(fkwargs::NT, f::FT, $(argexprs...)) where {NT, FT}
Base.@_inline_meta
@static if VERSION v"1.8"
# callsite inline syntax unsupported in <= 1.8
f($(argexprs...); fkwargs...)
else
@inline f($(argexprs...); fkwargs...)
end
end
end)
end

"""
test_reverse(f, Activity, args...; kwargs...)
Expand Down Expand Up @@ -43,6 +62,8 @@ end
Here we test a rule for a function of an array in batch reverse-mode:
```julia
x = randn(3)
for Tret in (Const, Active), Tx in (Const, BatchDuplicated)
Expand All @@ -61,7 +82,7 @@ function test_reverse(
testset_name=nothing,
)
call_with_copy(f, xs...) = deepcopy(f)(deepcopy(xs)...; deepcopy(fkwargs)...)
call_with_kwargs(f, xs...) = f(xs...; fkwargs...)
call_with_captured_kwargs(f, xs...) = f(xs...; fkwargs...)
if testset_name === nothing
testset_name = "test_reverse: $f with return activity $ret_activity on $(_string_activity(args))"
end
Expand All @@ -82,22 +103,29 @@ function test_reverse(
end
end
# call finitedifferences, avoid mutating original arguments
dx_fdm = _fd_reverse(fdm, call_with_kwargs, ȳ, activities)
dx_fdm = _fd_reverse(fdm, call_with_captured_kwargs, ȳ, activities, !(ret_activity <: Const))
# call autodiff, allow mutating original arguments
c_act = Const(call_with_kwargs)
forward, reverse = autodiff_thunk(
ReverseSplitWithPrimal, typeof(c_act), ret_activity, map(typeof, activities)...
ReverseSplitWithPrimal, typeof(c_act), ret_activity, typeof(Const(fkwargs)), map(typeof, activities)...
)
tape, y_ad, shadow_result = forward(c_act, activities...)
tape, y_ad, shadow_result = forward(c_act, Const(fkwargs), activities...)
if ret_activity <: Active
dx_ad = only(reverse(c_act, activities..., ȳ, tape))
dx_ad = only(reverse(c_act, Const(fkwargs), activities..., ȳ, tape))
else
# if there's a shadow result, then we need to set it to our random adjoint
if !(shadow_result === nothing)
map_fields_recursive(copyto!, shadow_result, ȳ)
if !_any_batch_duplicated(map(typeof, activities)...)
map_fields_recursive(copyto!, shadow_result, ȳ)
else
for (sr, dy) in zip(shadow_result, ȳ)
map_fields_recursive(copyto!, sr, dy)
end
end
end
dx_ad = only(reverse(c_act, activities..., tape))
dx_ad = only(reverse(c_act, Const(fkwargs), activities..., tape))
end
dx_ad = (dx_ad[1], dx_ad[3:end]...)
test_approx(
y_ad, y, "The return value of the rule and function must agree"; atol, rtol
)
Expand Down
4 changes: 2 additions & 2 deletions lib/EnzymeTestUtils/test/test_reverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ end
atol = rtol = sqrt(eps(real(T)))
@test !fails() do
test_reverse(f_multiarg, Tret, (x, Tx), (a, Ta); atol, rtol)
end broken = (VERSION < v"1.8" && Tx <: Const && T <: Complex)
end
end
end

Expand Down Expand Up @@ -121,7 +121,7 @@ end
atol = rtol = sqrt(eps(real(T)))
# https://github.com/EnzymeAD/Enzyme.jl/issues/877
test_broken = (
(VERSION > v"1.8" && T <: Real && !(Tc <: Const && Ty <: Const)) ||
(VERSION > v"1.8" && T <: Real) ||
(VERSION < v"1.8" && Tc <: Const)
)
if Tc <: BatchDuplicated && Ty <: BatchDuplicated
Expand Down
12 changes: 6 additions & 6 deletions src/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,10 @@ function EnzymeCreatePrimalAndGradient(logic, todiff, retType, constant_args, TA
uncacheable_args, augmented, atomicAdd)
freeMemory = true
ccall((:EnzymeCreatePrimalAndGradient, libEnzyme), LLVMValueRef,
(EnzymeLogicRef, LLVMValueRef, CDIFFE_TYPE, Ptr{CDIFFE_TYPE}, Csize_t,
(EnzymeLogicRef, LLVMValueRef, LLVM.API.LLVMBuilderRef, LLVMValueRef, CDIFFE_TYPE, Ptr{CDIFFE_TYPE}, Csize_t,
EnzymeTypeAnalysisRef, UInt8, UInt8, CDerivativeMode, Cuint, UInt8, LLVMTypeRef, UInt8, CFnTypeInfo,
Ptr{UInt8}, Csize_t, EnzymeAugmentedReturnPtr, UInt8),
logic, todiff, retType, constant_args, length(constant_args), TA, returnValue,
logic, C_NULL, C_NULL, todiff, retType, constant_args, length(constant_args), TA, returnValue,
dretUsed, mode, width, freeMemory, additionalArg, forceAnonymousTape, typeInfo, uncacheable_args, length(uncacheable_args),
augmented, atomicAdd)
end
Expand All @@ -140,10 +140,10 @@ function EnzymeCreateForwardDiff(logic, todiff, retType, constant_args, TA,
freeMemory = true
aug = C_NULL
ccall((:EnzymeCreateForwardDiff, libEnzyme), LLVMValueRef,
(EnzymeLogicRef, LLVMValueRef, CDIFFE_TYPE, Ptr{CDIFFE_TYPE}, Csize_t,
(EnzymeLogicRef, LLVMValueRef, LLVM.API.LLVMBuilderRef, LLVMValueRef, CDIFFE_TYPE, Ptr{CDIFFE_TYPE}, Csize_t,
EnzymeTypeAnalysisRef, UInt8, CDerivativeMode, UInt8, Cuint, LLVMTypeRef, CFnTypeInfo,
Ptr{UInt8}, Csize_t, EnzymeAugmentedReturnPtr),
logic, todiff, retType, constant_args, length(constant_args), TA, returnValue,
logic, C_NULL, C_NULL, todiff, retType, constant_args, length(constant_args), TA, returnValue,
mode, freeMemory, width, additionalArg, typeInfo, uncacheable_args, length(uncacheable_args), aug)
end

Expand All @@ -162,10 +162,10 @@ function EnzymeCreateAugmentedPrimal(logic, todiff, retType, constant_args, TA,
shadowReturnUsed,
typeInfo, uncacheable_args, forceAnonymousTape, width, atomicAdd)
ccall((:EnzymeCreateAugmentedPrimal, libEnzyme), EnzymeAugmentedReturnPtr,
(EnzymeLogicRef, LLVMValueRef, CDIFFE_TYPE, Ptr{CDIFFE_TYPE}, Csize_t,
(EnzymeLogicRef, LLVMValueRef, LLVM.API.LLVMBuilderRef, LLVMValueRef, CDIFFE_TYPE, Ptr{CDIFFE_TYPE}, Csize_t,
EnzymeTypeAnalysisRef, UInt8, UInt8,
CFnTypeInfo, Ptr{UInt8}, Csize_t, UInt8, Cuint, UInt8),
logic, todiff, retType, constant_args, length(constant_args), TA, returnUsed,
logic, C_NULL, C_NULL, todiff, retType, constant_args, length(constant_args), TA, returnUsed,
shadowReturnUsed,
typeInfo, uncacheable_args, length(uncacheable_args), forceAnonymousTape, width, atomicAdd)
end
Expand Down
Loading

0 comments on commit a5f78a0

Please sign in to comment.