Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into mhauru/distributions-…
Browse files Browse the repository at this point in the history
…integration-tests
  • Loading branch information
mhauru committed Nov 5, 2024
2 parents 4b2e56f + df20f18 commit 177b731
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 6 deletions.
4 changes: 3 additions & 1 deletion src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3257,6 +3257,7 @@ function annotate!(mod, mode)
if haskey(funcs, fname)
for fn in funcs[fname]
push!(function_attributes(fn), LLVM.StringAttribute("enzyme_nocache"))
push!(parameter_attributes(fn, 1), LLVM.EnumAttribute("nocapture"))
end
end
end
Expand Down Expand Up @@ -3452,7 +3453,7 @@ function annotate!(mod, mode)
LLVM.API.LLVMRemoveEnumAttributeAtIndex(
fn,
reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex),
kind(EnumAttribute("allockind")),
kind(EnumAttribute("allockind", AllocFnKind(AFKE_Alloc).data)),
)
push!(function_attributes(fn), no_escaping_alloc)
push!(function_attributes(fn), LLVM.EnumAttribute("allockind", (AllocFnKind(AFKE_Alloc) | AllocFnKind(AFKE_Uninitialized)).data))
Expand Down Expand Up @@ -3498,6 +3499,7 @@ function annotate!(mod, mode)
if haskey(funcs, fname)
for fn in funcs[fname]
push!(return_attributes(fn), LLVM.EnumAttribute("noalias", 0))
push!(return_attributes(fn), LLVM.EnumAttribute("nonnull", 0))
push!(function_attributes(fn), no_escaping_alloc)
push!(function_attributes(fn), LLVM.EnumAttribute("mustprogress"))
push!(function_attributes(fn), LLVM.EnumAttribute("willreturn"))
Expand Down
15 changes: 15 additions & 0 deletions test/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,18 @@ end
Enzyme.Const(dist)
)
end

@noinline function mc_g(i, _not_used)
k = (0.25)
return (i, k)
end

function mc_f(_not_used)
i = (0.0, 3.9555)
t = mc_g(i, _not_used)
return t[1][2]
end

@testset "Memcopy of constant" begin
@test Enzyme.autodiff(Enzyme.Forward, mc_f, Duplicated(2.7, 1.0))[1] 0.0
end
26 changes: 21 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2303,18 +2303,34 @@ end
@testset "Broadcast noalias" begin

x = ones(30)
autodiff(Reverse, bc0_test_function, Active, Const(x))

@static if VERSION < v"1.11-"
autodiff(Reverse, bc0_test_function, Active, Const(x))
else
# TODO
@test_broken autodiff(Reverse, bc0_test_function, Active, Const(x))
end

x = rand(Float32, 2, 3)
Enzyme.autodiff(Reverse, bc1_loss_function, Duplicated(x, zero(x)))
@static if VERSION < v"1.11-"
Enzyme.autodiff(Reverse, bc1_loss_function, Duplicated(x, zero(x)))
else
# TODO
@test_broken Enzyme.autodiff(Reverse, bc1_loss_function, Duplicated(x, zero(x)))
end

x = rand(Float32, 6, 6, 6, 2)
sc = rand(Float32, 6)
bi = rand(Float32, 6)
Enzyme.autodiff(Reverse, bc2_loss_function, Active, Duplicated(x, Enzyme.make_zero(x)),
Duplicated(sc, Enzyme.make_zero(sc)), Duplicated(bi, Enzyme.make_zero(bi)))
@static if VERSION < v"1.11-"
Enzyme.autodiff(Reverse, bc2_loss_function, Active, Duplicated(x, Enzyme.make_zero(x)),
Duplicated(sc, Enzyme.make_zero(sc)), Duplicated(bi, Enzyme.make_zero(bi)))
else
# TODO
@test_broken Enzyme.autodiff(Reverse, bc2_loss_function, Active, Duplicated(x, Enzyme.make_zero(x)),
Duplicated(sc, Enzyme.make_zero(sc)), Duplicated(bi, Enzyme.make_zero(bi)))
end
end

function solve_cubic_eq(poly::AbstractVector{Complex{T}}) where T
a1 = 1 / @inbounds poly[1]
E1 = 2*a1
Expand Down

0 comments on commit 177b731

Please sign in to comment.