Skip to content

Commit

Permalink
Add structural check for view (#985)
Browse files Browse the repository at this point in the history
* Add structural check for view

* with some active of array fixups
  • Loading branch information
wsmoses authored and michel2323 committed Nov 7, 2023
1 parent 36bd03c commit 1ff1a6d
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[compat]
CEnum = "0.4"
EnzymeCore = "0.5.1"
EnzymeCore = "0.5.2"
Enzyme_jll = "0.0.79"
GPUCompiler = "0.21"
LLVM = "6.1"
Expand Down
2 changes: 1 addition & 1 deletion lib/EnzymeCore/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "EnzymeCore"
uuid = "f151be2c-9106-41f4-ab19-57ee4f262869"
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>"]
version = "0.5.1"
version = "0.5.2"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
42 changes: 42 additions & 0 deletions lib/EnzymeCore/src/EnzymeCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ Enzyme will auto-differentiate in respect `Active` arguments.
"""
struct Active{T} <: Annotation{T}
val::T
@inline Active(x::T1) where {T1} = new{T1}(x)
@inline Active(x::T1) where {T1 <: AbstractArray} = error("Unsupported Active{"*string(T1)*"}, consider Duplicated or Const")
end
Adapt.adapt_structure(to, x::Active) = Active(adapt(to, x.val))

Expand All @@ -63,6 +65,15 @@ accumulator for gradients (so ``\\partial f / \\partial x`` will be *added to*)
struct Duplicated{T} <: Annotation{T}
val::T
dval::T
@inline Duplicated(x::T1, dx::T1, check::Bool=true) where {T1} = new{T1}(x, dx)
@inline function Duplicated(x::T1, dx::T1, check::Bool=true) where {T1 <: SubArray}
if check
@assert x.indices == dx.indices
@assert x.offset1 == dx.offset1
@assert x.stride1 == dx.stride1
end
new{T1}(x, dx)
end
end
Adapt.adapt_structure(to, x::Duplicated) = Duplicated(adapt(to, x.val), adapt(to, x.dval))

Expand All @@ -75,6 +86,15 @@ the original result and only compute the derivative values.
struct DuplicatedNoNeed{T} <: Annotation{T}
val::T
dval::T
@inline DuplicatedNoNeed(x::T1, dx::T1, check::Bool=true) where {T1} = new{T1}(x, dx)
@inline function DuplicatedNoNeed(x::T1, dx::T1, check::Bool=true) where {T1 <: SubArray}
if check
@assert x.indices == dx.indices
@assert x.offset1 == dx.offset1
@assert x.stride1 == dx.stride1
end
new{T1}(x, dx)
end
end
Adapt.adapt_structure(to, x::DuplicatedNoNeed) = DuplicatedNoNeed(adapt(to, x.val), adapt(to, x.dval))

Expand All @@ -87,6 +107,17 @@ for all at once. Argument `∂f_∂xs` should be a tuple of the several values o
struct BatchDuplicated{T,N} <: Annotation{T}
val::T
dval::NTuple{N,T}
@inline BatchDuplicated(x::T1, dx::NTuple{N,T1}, check::Bool=true) where {T1, N} = new{T1, N}(x, dx)
@inline function DuplicatedNoNeed(x::T1, dx::NTuple{N,T1}, check::Bool=true) where {T1 <: SubArray, N}
if check
for dxi in dx
@assert x.indices == dxi.indices
@assert x.offset1 == dxi.offset1
@assert x.stride1 == dxi.stride1
end
end
new{T1, N}(x, dx)
end
end
Adapt.adapt_structure(to, x::BatchDuplicated) = BatchDuplicated(adapt(to, x.val), adapt(to, x.dval))

Expand All @@ -105,6 +136,17 @@ for all at once. Argument `∂f_∂xs` should be a tuple of the several values o
struct BatchDuplicatedNoNeed{T,N} <: Annotation{T}
val::T
dval::NTuple{N,T}
@inline BatchDuplicatedNoNeed(x::T1, dx::NTuple{N,T1}, check::Bool=true) where {T1, N} = new{T1, N}(x, dx)
@inline function DuplicatedNoNeed(x::T1, dx::NTuple{N,T1}, check::Bool=true) where {T1 <: SubArray, N}
if check
for dxi in dx
@assert x.indices == dxi.indices
@assert x.offset1 == dxi.offset1
@assert x.stride1 == dxi.stride1
end
end
new{T1, N}(x, dx)
end
end
batch_size(::BatchDuplicated{T,N}) where {T,N} = N
batch_size(::BatchDuplicatedFunc{T,N}) where {T,N} = N
Expand Down
9 changes: 9 additions & 0 deletions lib/EnzymeCore/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,12 @@ function forward(::Const{typeof(f)}, ::Type{<:Const}; kwargs...)
end

@test has_frule_from_sig(Base.signature_type(f, Tuple{}))

data = [1.0, 2.0, 3.0, 4.0]

d = @view data[2:end]
y = @view data[3:end]
@test_throws ErrorException Duplicated(d, y)

@test_throws ErrorException Active(data)
@test_throws ErrorException Active(d)
6 changes: 5 additions & 1 deletion test/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,11 @@ function EnzymeRules.forward(
RT::Type{<:Union{Duplicated,BatchDuplicated}},
x::Union{Duplicated,BatchDuplicated},
)
return RT(func.val(x.val), map(func.val, x.dval))
if RT <: BatchDuplicated
return BatchDuplicated(func.val(x.val), map(func.val, x.dval))
else
return Duplicated(func.val(x.val), func.val(x.dval))
end
end

@testset "Batch complex" begin
Expand Down

0 comments on commit 1ff1a6d

Please sign in to comment.