Skip to content

Commit

Permalink
turn active shadow tuple to zero (#2281)
Browse files Browse the repository at this point in the history
* turn active shadow tuple to zero

* 1.10.8

* update test
  • Loading branch information
wsmoses authored Jan 23, 2025
1 parent 18c95ba commit a74ec0d
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 12 deletions.
8 changes: 4 additions & 4 deletions src/rules/jitrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1140,11 +1140,10 @@ end
) where {Ann,Nargs}
expr = Vector{Expr}(undef, Nargs)
for i = 1:Nargs
if args[i] <: Active
throw(AssertionError("Unsupported Active arg $(args[i])"))
end
@inbounds expr[i] = if args[i] <: Const
:(args[$i].val)
elseif args[i] <: Active
:(Enzyme.make_zero(args[$i].val))
elseif args[i] <: MixedDuplicated
:(args[$i].dval[])
else
Expand All @@ -1170,9 +1169,10 @@ end
for w = 1:width
expr = Vector{Expr}(undef, Nargs)
for i = 1:Nargs
@assert !(args[i] <: Active)
@inbounds expr[i] = if args[i] <: Const
:(args[$i].val)
elseif args[i] <: Active
:(Enzyme.make_zero(args[$i].val))
elseif args[i] <: BatchMixedDuplicated
:(args[$i].dval[$w][])
else
Expand Down
4 changes: 3 additions & 1 deletion test/applyiter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,9 @@ end
data = [[3.0], nothing, 2.0]
ddata = [[0.0], nothing, 0.0]

@test_throws AssertionError Enzyme.autodiff(Reverse, mktup2, Duplicated(data, ddata))
Enzyme.autodiff(Reverse, mktup2, Duplicated(data, ddata))
@test ddata[1][1] 2.0
@test ddata[3] 3.0

function mktup3(v)
tup = tuple(v..., v...)
Expand Down
15 changes: 8 additions & 7 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1845,15 +1845,16 @@ end
dR = zeros(6, 6)

@static if VERSION v"1.11-"
elseif VERSION v"1.10.8"
autodiff(Reverse, whocallsmorethan30args, Active, Duplicated(R, dR))
@test 1.0 dR[1, 1]
@test 1.0 dR[2, 2]
@test 1.0 dR[3, 3]
@test 1.0 dR[4, 4]
@test 1.0 dR[5, 5]
@test 0.0 dR[6, 6]
else
@test_broken autodiff(Reverse, whocallsmorethan30args, Active, Duplicated(R, dR))
# autodiff(Reverse, whocallsmorethan30args, Active, Duplicated(R, dR))
# @test 1.0 ≈ dR[1, 1]
# @test 1.0 ≈ dR[2, 2]
# @test 1.0 ≈ dR[3, 3]
# @test 1.0 ≈ dR[4, 4]
# @test 1.0 ≈ dR[5, 5]
# @test 0.0 ≈ dR[6, 6]
end
end

Expand Down

0 comments on commit a74ec0d

Please sign in to comment.