Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JET integration #261

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Enzyme_jll = "7cc45869-7501-5eee-bdea-0790c847d4ef"
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -20,6 +21,7 @@ Adapt = "3.3"
CEnum = "0.4"
Enzyme_jll = "0.0.29"
GPUCompiler = "0.14"
JET = "0.5.10"
LLVM = "4.1"
ObjectFile = "0.3"
julia = "1.6"
25 changes: 13 additions & 12 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ include("compiler.jl")

import .Compiler: CompilationException

include("JET.jl")

# @inline annotate() = ()
# @inline annotate(arg::A, args::Vararg{Any, N}) where {A<:Annotation, N} = (arg, annotate(args...)...)
# @inline annotate(arg, args::Vararg{Any, N}) where N = (Const(arg), annotate(args...)...)
Expand All @@ -116,6 +118,9 @@ import .Compiler: CompilationException
end
end

# annotated args to argtypes
getargtypes(args′) = Tuple{map(@nospecialize(t)->eltype(Core.Typeof(t)), args′)...}

prepare_cc() = ()
prepare_cc(arg::Duplicated, args...) = (arg.val, arg.dval, prepare_cc(args...)...)
prepare_cc(arg::DuplicatedNoNeed, args...) = (arg.val, arg.dval, prepare_cc(args...)...)
Expand Down Expand Up @@ -178,16 +183,16 @@ while ``\\partial f/\\partial b`` will be *added to* `∂f_∂b` (but not return
args′ = annotate(args...)
tt′ = Tuple{map(Core.Typeof, args′)...}
if A <: Active
tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...}
rt = Core.Compiler.return_type(f, tt)
rt = Core.Compiler.return_type(f, getargtypes(args′))
if !allocatedinline(rt)
forward, adjoint = Enzyme.Compiler.thunk(f, #=df=#nothing, Duplicated{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient))
res = forward(args′...)
tape = res[1]
if res[3] isa Base.RefValue
res[3][] += one(eltype(typeof(res[3])))
res3 = res[3]
if res3 isa Base.RefValue
res3[] += one(eltype(res3))
else
res[3] += one(eltype(typeof(res[3])))
res3 += one(eltype(res3))
end
return adjoint(args′..., tape)
end
Expand Down Expand Up @@ -220,18 +225,14 @@ Like [`autodiff`](@ref) but will try to guess the activity of the return value.
"""
@inline function autodiff(f::F, args...) where {F}
args′ = annotate(args...)
tt′ = Tuple{map(Core.Typeof, args′)...}
tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...}
rt = Core.Compiler.return_type(f, tt)
rt = Core.Compiler.return_type(f, getargtypes(args′))
A = guess_activity(rt)
autodiff(f, A, args′...)
end

@inline function autodiff(dupf::Duplicated{F}, args...) where {F}
args′ = annotate(args...)
tt′ = Tuple{map(Core.Typeof, args′)...}
tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...}
rt = Core.Compiler.return_type(dupf.val, tt)
rt = Core.Compiler.return_type(dupf.val, getargtypes(args′))
A = guess_activity(rt)
autodiff(dupf, A, args′...)
end
Expand Down Expand Up @@ -280,7 +281,7 @@ Like [`autodiff_deferred`](@ref) but will try to guess the activity of the retur
tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...}
rt = Core.Compiler.return_type(f, tt)
rt = guess_activity(rt)
autodiff_deferred(f, rt, args′...)
autodiff_deferred(f, rt, args′...)
end

using Adapt
Expand Down
202 changes: 202 additions & 0 deletions src/JET.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
# This file defines JET analysis to check Enzyme's auto-differentiability.
# In particularity, the analysis will look for:
# - any active `Union`-typed result
# - (unimplemented yet) dynamic dispatch
# - and more...?

# exports
# -------

export report_enzyme, @report_enzyme, test_enzyme, @test_enzyme

# analyzer
# --------

using JET.JETInterface

struct EnzymeAnalyzer <: AbstractAnalyzer
state::AnalyzerState
cache_key::UInt
end
function EnzymeAnalyzer(; jetconfigs...)
return EnzymeAnalyzer(
AnalyzerState(; jetconfigs...),
__ANALYZER_CACHE_KEY[] += 1,
)
end
const __ANALYZER_CACHE_KEY = Ref{UInt}(0)

JETInterface.AnalyzerState(analyzer::EnzymeAnalyzer) = analyzer.state
function JETInterface.AbstractAnalyzer(analyzer::EnzymeAnalyzer, state::AnalyzerState)
return EnzymeAnalyzer(
state,
analyzer.cache_key,
)
end
JETInterface.ReportPass(analyzer::EnzymeAnalyzer) = EnzymeAnalysisPass()
JETInterface.get_cache_key(analyzer::EnzymeAnalyzer) = analyzer.cache_key

function Core.Compiler.finish(frame::Core.Compiler.InferenceState, analyzer::EnzymeAnalyzer)
ReportPass(analyzer)(UnionResultReport, analyzer, frame)
return Base.@invoke Core.Compiler.finish(frame::Core.Compiler.InferenceState, analyzer::AbstractAnalyzer)
end

struct EnzymeAnalysisPass <: ReportPass end

@reportdef struct UnionResultReport <: InferenceErrorReport
@nospecialize(tt) # ::Type
@nospecialize(rt) # ::Type
end
JETInterface.get_msg(::Type{UnionResultReport}, @nospecialize(args...)) = "potentially active Union result detected"
function JETInterface.print_error_report(io, report::UnionResultReport)
Base.@invoke JETInterface.print_error_report(io, report::InferenceErrorReport)
printstyled(io, "::", report.rt; color = :cyan)
end

# check if this return value is used in the caller and its type is inferred as Union:
# XXX `Core.Compiler.call_result_unused` is a very inaccurate model of Enzyme's activity analysis,
# and so this active Union return check might be very incomplete
function (::EnzymeAnalysisPass)(::Type{UnionResultReport}, analyzer::EnzymeAnalyzer, frame::Core.Compiler.InferenceState)
parent = frame.parent
if isa(parent, Core.Compiler.InferenceState) && !(Core.Compiler.call_result_unused(parent))
rt = frame.bestguess
if isa(rt, Union)
add_new_report!(analyzer, frame.result, UnionResultReport(frame.linfo, frame.linfo.specTypes, rt))
return true
end
end
return false
end

# entry
# -----

import JET: report_call, test_call, get_reports

analyze_autodiff_call(entry, f, ::Type{<:Annotation}, args...) = (@nospecialize; _analyze_autodiff_call(entry, f, args...))
analyze_autodiff_call(entry, f, args...) = (@nospecialize; _analyze_autodiff_call(entry, f, args...))
function _analyze_autodiff_call(entry, f, args...)
@nospecialize f args
args′ = annotate(args...)
tt = getargtypes(args′)
return entry(f, tt; analyzer=EnzymeAnalyzer)
end

function apply_autodiff_args(f, @nospecialize(ex))
if Meta.isexpr(ex, :do)
dof = esc(ex.args[2])
autodiff′, args... = map(esc, ex.args[1].args)
return quote
if $autodiff′ !== autodiff
throw(ArgumentError("@$($f) expects `autodiff(...)` call expression"))
end
$f($dof, $(args...))
end
elseif !(Meta.isexpr(ex, :call) && length(ex.args) ≥ 1)
throw(ArgumentError("@$f expects `autodiff(...)` call expression"))
end
autodiff′, args... = map(esc, ex.args)
return quote
if $autodiff′ !== autodiff
throw(ArgumentError("@$($f) expects `autodiff(...)` call expression"))
end
$f($(args...))
end
end

"""
report_enzyme(args...) -> result::JETCallResult

Analyzes potential problems for Enzyme to auto-differentiate `args`.
`args` should be valid arguments to [`autodiff`](@ref) function, i.e.
the call `autodiff(args...)` should meet the `autodiff` interface.

In particularity, `report_enzyme` detects if there is any potentially active `Union`-typed
result, which confuses Enzymes's code generation.
If such `Union`-typed result is unused anywhere, `report_enzyme` doesn't report it as an issue,
since Enzyme can auto-differentiate it without problem.

Note that this analysis is _not_ complete in terms of covering Enzyme's auto-differentiability --
`report_enzyme` models Enzyme's activity analysis very inaccurately, meaning there may be
some code that Enzyme differentiates without any problem while `report_enzyme` raises an issue.

```julia
julia> union_result(cond, x) = cond ? x : 0
union_result (generic function with 1 method)

julia> report_enzyme(Active, true, Active(1.0)) do cond, x
union_result(cond, x) * x
end
═════ 1 possible error found ═════
┌ @ none:2 Main.union_result(cond, x)
│┌ @ none:1 union_result(::Bool, ::Float64)
││ potentially active Union result detected: union_result(::Bool, ::Float64)::Union{Float64, Int64}
│└──────────

julia> report_enzyme(Active, true, Active(1.0)) do cond, x
union_result(cond, x) # inactive Union-typed result
x * x
end
No errors detected

julia> union_result(cond, x) = cond ? x : zero(x) # fix the Union-typed result
union_result (generic function with 1 method)

julia> report_enzyme(Active, true, Active(1.0)) do cond, x
union_result(cond, x) * x
end
No errors detected
```
"""
report_enzyme(args...) = (@nospecialize; analyze_autodiff_call(report_call, args...))

"""
@report_enzyme autodiff(...)

Takes valid [`autodiff`](@ref) call expression and analyzes potential problems for Enzyme to
auto-differentiate it.

See also [`report_enzyme`](@ref).

```julia
julia> union_result(cond, x) = cond ? x : 0
union_result (generic function with 1 method)

julia> @report_enzyme autodiff(Active, true, Active(1.0)) do cond, x
union_result(cond, x) * x
end
═════ 1 possible error found ═════
┌ @ none:2 Main.union_result(cond, x)
│┌ @ none:1 union_result(::Bool, ::Float64)
││ potentially active Union result detected: union_result(::Bool, ::Float64)::Union{Float64, Int64}
│└──────────
end
```
"""
macro report_enzyme(ex) apply_autodiff_args(report_enzyme, ex) end

# TODO support test configurations?

"""
test_enzyme(args...) -> JETCallResult

Tests `args` can be safely auto-differentiated by Enzyme.jl
`args` should be valid arguments to [`autodiff`](@ref) function, i.e.
the call `autodiff(args...)` should meet the `autodiff` interface.

See also [`@test_enzyme`](@ref), [`report_enzyme`](@ref).
"""
test_enzyme(args...) = (@nospecialize; analyze_autodiff_call(test_call, args...))

"""
@test_enzyme autodiff(...)

Tests the given [`autodiff`](@ref) call can be safely auto-differentiated by Enzyme.
Returns a `Pass` result if it is, a `Fail` result if if contains any potential problems,
or an `Error` result if this macro encounters an unexpected error.
When the test `Fail`s, abstract call stack to each problem location will also be printed
to `stdout`.

See also [`report_enzyme`](@ref).
"""
macro test_enzyme(ex) apply_autodiff_args(test_enzyme, ex) end
Loading