Support for eigen-solvers #2264

Open
MasonProtter opened this issue Jan 13, 2025 · 4 comments

@MasonProtter

This currently doesn't work:

julia> using Enzyme, LinearAlgebra

julia> let 
           autodiff(Forward, Duplicated(1.0, 1.0)) do x
               M = [x    1+x
                    1+x' x^2]
               
               sum(eigvecs(M))
           end
       end
ERROR: 
No forward mode derivative found for ejlstr$dsyevr_64_$libblastrampoline.so.5
 at context:   call void @"ejlstr$dsyevr_64_$libblastrampoline.so.5"(i8* noundef nonnull %5, i8* noundef nonnull %6, i8* noundef nonnull %7, i8* noundef nonnull %9, i64 %174, i8* noundef nonnull %11, i8* noundef nonnull %13, i8* noundef nonnull %15, i8* noundef nonnull %17, i8* noundef nonnull %19, i8* noundef nonnull %21, i64 noundef %149, i64 %157, i64 %175, i8* noundef nonnull %23, i64 %150, i64 %176, i8* noundef nonnull %25, i64 %177, i8* noundef nonnull %27, i8* noundef nonnull %4, i64 noundef 1, i64 noundef 1, i64 noundef 1) #141 [ "jl_roots"({} addrspace(10)* null, {} addrspace(10)* null, { i8*, {} addrspace(10)* } %173, {} addrspace(10)* null, { i8*, {} addrspace(10)* } %169, { i8*, {} addrspace(10)* } %107, {} addrspace(10)* null, { i8*, {} addrspace(10)* } %165, { i8*, {} addrspace(10)* } %156, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, { i8*, {} addrspace(10)* } %161, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null) ], !dbg !262

Stacktrace:
 [1] syevr!
   @ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/lapack.jl:5397


Stacktrace:
  [1] syevr!
    @ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/lapack.jl:5397
  [2] eigen!
    @ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/symmetriceigen.jl:8 [inlined]
  [3] #_eigen#96
    @ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/eigen.jl:250 [inlined]
  [4] fwddiffejulia___eigen_96_12524wrap
    @ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/eigen.jl:0
  [5] macro expansion
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5340 [inlined]
  [6] enzyme_call
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:4878 [inlined]
  [7] ForwardModeThunk
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:4766 [inlined]
  [8] runtime_generic_fwd(activity::Type{…}, runtimeActivity::Val{…}, width::Val{…}, RT::Val{…}, f::LinearAlgebra.var"##_eigen#96", df::Nothing, primal_1::Bool, shadow_1_1::Nothing, primal_2::Bool, shadow_2_1::Nothing, primal_3::typeof(LinearAlgebra.eigsortby), shadow_3_1::Nothing, primal_4::typeof(LinearAlgebra._eigen), shadow_4_1::Nothing, primal_5::Matrix{…}, shadow_5_1::Matrix{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/rules/jitrules.jl:303
  [9] _eigen
    @ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/eigen.jl:247 [inlined]
 [10] #eigen#94
    @ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/eigen.jl:239 [inlined]
 [11] eigen
    @ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/eigen.jl:238 [inlined]
 [12] eigvecs
    @ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/eigen.jl:274 [inlined]
 [13] #1
    @ ./REPL[2]:6 [inlined]
 [14] fwddiffejulia__1_7197_inner_1wrap
    @ ./REPL[2]:0
 [15] macro expansion
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5340 [inlined]
 [16] enzyme_call
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:4878 [inlined]
 [17] ForwardModeThunk
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:4766 [inlined]
 [18] autodiff
    @ ~/.julia/packages/Enzyme/R6sE8/src/Enzyme.jl:654 [inlined]
 [19] autodiff(mode::ForwardMode{false, FFIABI, true, false}, f::Const{var"#1#2"}, args::Duplicated{Float64})
    @ Enzyme ~/.julia/packages/Enzyme/R6sE8/src/Enzyme.jl:544
 [20] autodiff
    @ ~/.julia/packages/Enzyme/R6sE8/src/Enzyme.jl:516 [inlined]
 [21] autodiff(f::Function, m::ForwardMode{false, FFIABI, false, false}, args::Duplicated{Float64})
    @ Enzyme ~/.julia/packages/Enzyme/R6sE8/src/Enzyme.jl:1019
 [22] top-level scope
    @ REPL[2]:2
Some type information was truncated. Use `show(err)` to see complete types.

It'd be really nice to be able to differentiate through eigen-solver calls.
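For reference, here is the forward-mode (tangent) formula that a rule for the symmetric path hit above (dsyevr) would need to implement. It is the standard first-order perturbation result for a real symmetric A = V Λ Vᵀ with distinct eigenvalues: dλᵢ = vᵢᵀ dA vᵢ and dV = V (F ∘ (Vᵀ dA V)), where F[i,j] = 1/(λⱼ − λᵢ) off the diagonal and 0 on it. This is the same structure as the ChainRules.jl rule adapted below. It is a minimal sketch, not an Enzyme rule: `sym_eigen_tangent` is a hypothetical helper, it assumes distinct eigenvalues, and it uses the convention vᵀdv = 0, since eigenvector signs are otherwise arbitrary.

```julia
using LinearAlgebra

# Hypothetical helper: tangent of a real symmetric eigendecomposition.
# Assumes distinct eigenvalues. Given A = V Λ V' and a tangent dA:
#   dλ[i] = v_i' dA v_i
#   dV    = V * (F .* (V' dA V)), F[i,j] = 1/(λ[j] - λ[i]) for i ≠ j, 0 on the diagonal
function sym_eigen_tangent(A::AbstractMatrix{<:Real}, dA::AbstractMatrix{<:Real})
    λ, V = eigen(Symmetric(A))   # eigenvalues in ascending order
    K = V' * dA * V              # tangent expressed in the eigenbasis
    dλ = diag(K)
    F = [i == j ? 0.0 : 1 / (λ[j] - λ[i]) for i in eachindex(λ), j in eachindex(λ)]
    dV = V * (F .* K)
    return λ, V, dλ, dV
end

# The 2×2 matrix from the report (x is real, so x' == x) and its derivative dM/dx
M(x) = [x 1+x; 1+x x^2]
dMdx(x) = [1.0 1.0; 1.0 2x]

λ, V, dλ, dV = sym_eigen_tangent(M(1.0), dMdx(1.0))
@show λ    # [-1.0, 3.0] up to rounding
@show dλ   # [0.5, 2.5] up to rounding
```

Comparing dλ and dV against central finite differences (after aligning eigenvector signs) is a cheap way to validate any eventual Enzyme rule for this case.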

@wsmoses
Member

wsmoses commented Jan 13, 2025

@michel2323, potentially another one for the BLAS rules?

@MasonProtter
Author

So today I tried to write an EnzymeRule for this based on the one in ChainRules.jl, but I'm running into trouble because eigen is type-unstable. @wsmoses, any advice on how to deal with type-unstable forward rules?

using Enzyme, LinearAlgebra
using LinearAlgebra: BlasFloat

import .EnzymeRules: forward, reverse, augmented_primal
using .EnzymeRules

function forward(config::FwdConfig,
                 func::Const{typeof(eigen!)},
                 ::Type{<:Duplicated},
                 A::Duplicated; kwargs...)
    A, ΔA = A.val, A.dval
    if ishermitian(A)
        error("Not yet implemented")
    end
    # Adapted from ChainRules.jl
    F = eigen!(A; kwargs...)::Eigen{ComplexF64, ComplexF64, Matrix{ComplexF64}, Vector{ComplexF64}}
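    λ, V = F.values, F.vectors
    # Tangent in the eigenbasis: ∂K = V⁻¹ ΔA V
    tmp = V \ ΔA
    ∂K = tmp * V
    # Eigenvalue tangents are the diagonal of ∂K
    ∂Kdiag = @view ∂K[diagind(∂K)]
    ∂λ = eltype(λ) <: Real ? real.(∂Kdiag) : copy(∂Kdiag)
    # Eigenvector tangents: divide entry (i, j) by λ[j] - λ[i], then zero the diagonal
    ∂K ./= transpose(λ) .- λ
    fill!(∂Kdiag, 0)
    ∂V = mul!(tmp, V, ∂K)
    # Correct for the normalization/phase convention of the returned eigenvectors
    _eigen_norm_phase_fwd!(∂V, A, V)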
    ∂F = Eigen(∂λ, ∂V)::Eigen{ComplexF64, ComplexF64, Matrix{ComplexF64}, Vector{ComplexF64}}
    Duplicated(F, ∂F)
end

function _eigen_norm_phase_fwd!(∂V, A, V)
    # From ChainRules.jl
    @inbounds for i in axes(V, 2)
        v, ∂v = @views V[:, i], ∂V[:, i]
        # account for unit normalization
        ∂c_norm = -realdot(v, ∂v)
        if eltype(V) <: Real
            ∂c = ∂c_norm
        else
            # account for rotation of largest element to real
            k = _findrealmaxabs2(v)
            ∂c_phase = -imag(∂v[k]) / real(v[k])
            ∂c = complex(∂c_norm, ∂c_phase)
        end
        ∂v .+= v .* ∂c
    end
    return ∂V
end

# From https://github.com/JuliaMath/RealDot.jl/blob/main/src/RealDot.jl
@inline realdot(x, y) = real(LinearAlgebra.dot(x, y))
@inline realdot(x::Complex, y::Complex) = muladd(real(x), real(y), imag(x) * imag(y))
@inline realdot(x::Real, y::Number) = x * real(y)
@inline realdot(x::Number, y::Real) = real(x) * y
@inline realdot(x::Real, y::Real) = x * y

# From ChainRules.jl
function _findrealmaxabs2(x)
    amax = abs2(first(x))
    imax = 1
    @inbounds for i in 2:length(x)
        xi = x[i]
        !isreal(xi) && continue
        a = abs2(xi)
        a < amax && continue
        amax, imax = a, i
    end
    return imax
end
julia> let 
           autodiff(Forward, Duplicated(1.0, 1.0)) do x
               M = [x-im    1+x
                    1-x x^2]
               λ, V = eigen(M)
               sum(V) - sum(λ)
           end
       end
ERROR: Enzyme execution failed.
Enzyme: incorrect return type of prima/shadow forward custom rule - FwdConfigWidth{1, true, true, false} Duplicated{Union{Eigen{ComplexF64, ComplexF64, Matrix{ComplexF64}, Vector{ComplexF64}}, Eigen{ComplexF64, Float64, Matrix{ComplexF64}, Vector{Float64}}}} Type[Const{typeof(Core.kwcall)}, Const{typeof(eigen!)}, Duplicated{Matrix{ComplexF64}}] want just shadow type Duplicated{Union{Eigen{ComplexF64, ComplexF64, Matrix{ComplexF64}, Vector{ComplexF64}}, Eigen{ComplexF64, Float64, Matrix{ComplexF64}, Vector{Float64}}}} found Duplicated{Eigen{ComplexF64, ComplexF64, Matrix{ComplexF64}, Vector{ComplexF64}}}

Stacktrace:
  [1] #_eigen#96
    @ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/eigen.jl:252 [inlined]
  [2] fwddiffejulia___eigen_96_47518wrap
    @ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/eigen.jl:0
  [3] macro expansion
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5340 [inlined]
  [4] enzyme_call
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:4878 [inlined]
  [5] ForwardModeThunk
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:4766 [inlined]
  [6] runtime_generic_fwd(activity::Type{…}, runtimeActivity::Val{…}, width::Val{…}, RT::Val{…}, f::LinearAlgebra.var"##_eigen#96", df::Nothing, primal_1::Bool, shadow_1_1::Nothing, primal_2::Bool, shadow_2_1::Nothing, primal_3::typeof(LinearAlgebra.eigsortby), shadow_3_1::Nothing, primal_4::typeof(LinearAlgebra._eigen), shadow_4_1::Nothing, primal_5::Matrix{…}, shadow_5_1::Matrix{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/rules/jitrules.jl:303
  [7] _eigen
    @ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/eigen.jl:247 [inlined]
  [8] #eigen#94
    @ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/eigen.jl:239 [inlined]
  [9] eigen
    @ ~/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/eigen.jl:238 [inlined]
 [10] #23
    @ ./REPL[6]:5 [inlined]
 [11] fwddiffejulia__23_47406_inner_1wrap
    @ ./REPL[6]:0
 [12] macro expansion
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5340 [inlined]
 [13] enzyme_call
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:4878 [inlined]
 [14] ForwardModeThunk
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:4766 [inlined]
 [15] autodiff
    @ ~/.julia/packages/Enzyme/R6sE8/src/Enzyme.jl:654 [inlined]
 [16] autodiff(mode::ForwardMode{false, FFIABI, true, false}, f::Const{var"#23#24"}, args::Duplicated{Float64})
    @ Enzyme ~/.julia/packages/Enzyme/R6sE8/src/Enzyme.jl:544
 [17] autodiff
    @ ~/.julia/packages/Enzyme/R6sE8/src/Enzyme.jl:516 [inlined]
 [18] autodiff(f::Function, m::ForwardMode{false, FFIABI, false, false}, args::Duplicated{Float64})
    @ Enzyme ~/.julia/packages/Enzyme/R6sE8/src/Enzyme.jl:1019
 [19] top-level scope
    @ REPL[6]:2
Some type information was truncated. Use `show(err)` to see complete types.
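The Union in the error message above comes from eigen! itself: for a complex matrix it checks ishermitian at runtime, sends Hermitian input to the Hermitian solver (real eigenvalues) and everything else to the general solver (complex eigenvalues). The concrete Eigen type therefore depends on the matrix values, not just its type. A minimal demonstration using only LinearAlgebra (Base.promote_op is an internal Base helper, used here only to query inference):

```julia
using LinearAlgebra

H = ComplexF64[2 1-im; 1+im 3]   # Hermitian: takes the Hermitian branch
N = ComplexF64[1 2; 0 3]         # not Hermitian: takes the general branch

FH = eigen(H)
FN = eigen(N)

@show eltype(FH.values)   # Float64
@show eltype(FN.values)   # ComplexF64

# Same argument type, two possible result types, so inference cannot
# produce a single concrete return type for Matrix{ComplexF64}.
@show isconcretetype(Base.promote_op(eigen, Matrix{ComplexF64}))   # false
```

A forward rule that always asserts Eigen{ComplexF64, ComplexF64, Matrix{ComplexF64}, Vector{ComplexF64}} therefore disagrees with the Union-typed Duplicated that Enzyme reports it expects for this call.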

@hz-xiaxz

Hi, I'm interested in this feature, too. I tried ArnoldiMethod.jl, which is written in pure Julia, but it didn't succeed. The code is as follows (sorry, I used Claude to generate it because I'm a complete beginner with this package):

using ArnoldiMethod
using LinearAlgebra
using SparseArrays
using Enzyme

function compute_eigenvalues(diag_elements, subdiag_elements, nev)
    n = length(diag_elements)
    
    A = spdiagm(
        -1 => subdiag_elements,
        0 => diag_elements,
        1 => subdiag_elements
    )
    
    decomp, history = partialschur(A, nev=nev, tol=1e-6, which=:SR)
    
    eigenvalues = decomp.eigenvalues
    
    return real.(eigenvalues)  
end

function eigenvalue_sum(diag_val, subdiag_val, n, nev)
    diag_elements = fill(diag_val, n)
    subdiag_elements = fill(subdiag_val, n-1)
    eigenvalues = compute_eigenvalues(diag_elements, subdiag_elements, nev)
    return sum(eigenvalues)
end

function minimal_example()
    println("=== Minimal Working Example for Enzyme.jl with ArnoldiMethod.jl ===")
    
    n = 10      
    nev = 3
    
    diag_val = 2.0
    subdiag_val = -1.0
    
    sum_val = eigenvalue_sum(diag_val, subdiag_val, n, nev)
    println("Sum of $nev smallest eigenvalues: $sum_val")
    
    println("\n--- CORRECT USAGE WITH DUPLICATED ---")
    # Derivative with respect to the diagonal value (forward mode, seed 1.0).
    # The return activity must be Duplicated (not Const) for the derivative to be returned.
    diag_val_dub = Enzyme.Duplicated(diag_val, 1.0)
    diag_grad = Enzyme.autodiff(Forward, eigenvalue_sum, Duplicated, diag_val_dub, Const(subdiag_val), Const(n), Const(nev))[1]
    
    # Derivative with respect to the subdiagonal value
    subdiag_val_dub = Enzyme.Duplicated(subdiag_val, 1.0)
    subdiag_grad = Enzyme.autodiff(Forward, eigenvalue_sum, Duplicated, Const(diag_val), subdiag_val_dub, Const(n), Const(nev))[1]
    
    println("Gradient w.r.t diagonal value: $diag_grad")
    println("Gradient w.r.t subdiagonal value: $subdiag_grad")
end

The program fails with the following warnings and error:

julia> include("arnoldi_enzyme_example.jl")
┌ Warning: Using fallback BLAS replacements for (["dasum_64_"]), performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/OGnEB/src/utils.jl:61
┌ Warning: Using fallback BLAS replacements for (["dasum_64_"]), performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/OGnEB/src/utils.jl:61
┌ Warning: Using fallback BLAS replacements for (["dasum_64_"]), performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/OGnEB/src/utils.jl:61
┌ Warning: Using fallback BLAS replacements for (["dasum_64_"]), performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/OGnEB/src/utils.jl:61
=== Minimal Working Example for Enzyme.jl with ArnoldiMethod.jl ===
Sum of 3 smallest eigenvalues: 1.0887855192180709

--- CORRECT USAGE WITH DUPLICATED ---
┌ Warning: Using fallback BLAS replacements for (["dasum_64_"]), performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/OGnEB/src/utils.jl:61
ERROR: LoadError: AssertionError: i == Int(length(LLVM.elements(ty))) + 1
Stacktrace:
  [1] zero_single_allocation(builder::LLVM.IRBuilder, jlType::DataType, LLVMType::LLVM.LLVMType, nobj::LLVM.Value, zeroAll::Bool, idx::LLVM.Value; write_barrier::Bool, atomic::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/QsaeA/src/compiler.jl:846
  [2] zero_single_allocation
    @ ~/.julia/packages/Enzyme/QsaeA/src/compiler.jl:762 [inlined]
  [3] create_recursive_stores(B::LLVM.IRBuilder, Ty::DataType, prev::LLVM.Value)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/QsaeA/src/compiler.jl:611
  [4] shadow_alloc_rewrite(V::Ptr{LLVM.API.LLVMOpaqueValue}, gutils::Ptr{Nothing}, Orig::Ptr{LLVM.API.LLVMOpaqueValue}, idx::UInt64, prev::Ptr{LLVM.API.LLVMOpaqueValue}, used::UInt8)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/QsaeA/src/compiler.jl:702
  [5] EnzymeCreateForwardDiff(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, mode::Enzyme.API.CDerivativeMode, runtimeActivity::Bool, width::Int64, additionalArg::Ptr{…}, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…})
    @ Enzyme.API ~/.julia/packages/Enzyme/QsaeA/src/api.jl:334
  [6] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::NTuple{…} where N, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/QsaeA/src/compiler.jl:1745
  [7] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/QsaeA/src/compiler.jl:4550
  [8] codegen
    @ ~/.julia/packages/Enzyme/QsaeA/src/compiler.jl:3353 [inlined]
  [9] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/QsaeA/src/compiler.jl:5410
 [10] _thunk
    @ ~/.julia/packages/Enzyme/QsaeA/src/compiler.jl:5410 [inlined]
 [11] cached_compilation
    @ ~/.julia/packages/Enzyme/QsaeA/src/compiler.jl:5462 [inlined]
 [12] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, edges::Vector{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/QsaeA/src/compiler.jl:5573
 [13] thunk_generator(world::UInt64, source::LineNumberNode, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/QsaeA/src/compiler.jl:5758
 [14] autodiff
    @ ~/.julia/packages/Enzyme/QsaeA/src/Enzyme.jl:640 [inlined]
 [15] autodiff
    @ ~/.julia/packages/Enzyme/QsaeA/src/Enzyme.jl:524 [inlined]
 [16] minimal_example()
    @ Main ~/.julia/dev/arnoldi_enzyme_example.jl:56
 [17] top-level scope
    @ ~/.julia/dev/arnoldi_enzyme_example.jl:82
 [18] include(fname::String)
    @ Main ./sysimg.jl:38
 [19] top-level scope
    @ REPL[7]:1
in expression starting at /home/hzxiaxz/.julia/dev/arnoldi_enzyme_example.jl:82
Some type information was truncated. Use `show(err)` to see complete types.
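Independently of Enzyme, the expected derivatives for this particular test problem can be computed directly, which gives a reference to check any future result against. The sketch below uses dense eigen and the first-order relation dλₖ/dθ = vₖᵀ (∂A/∂θ) vₖ (valid because the eigenvalues here are distinct), with n = 10 and nev = 3, the values behind the printed sum 1.0887855192180709. The helpers `tridiag` and `eigsum_and_derivs` are hypothetical. For b < 0 the eigenvalues of this symmetric tridiagonal Toeplitz matrix are λₖ = a + 2b·cos(kπ/(n+1)) in ascending order, so ∂s/∂a = nev and ∂s/∂b = Σₖ 2cos(kπ/(n+1)) for k = 1..nev.

```julia
using LinearAlgebra

# Hypothetical helper: the same matrix that compute_eigenvalues builds with spdiagm, stored densely.
tridiag(a, b, n) = Matrix(SymTridiagonal(fill(a, n), fill(b, n - 1)))

# Hypothetical helper: sum of the nev smallest eigenvalues and its derivatives
# w.r.t. the diagonal value a and the off-diagonal value b.
function eigsum_and_derivs(a, b, n, nev)
    λ, V = eigen(Symmetric(tridiag(a, b, n)))   # ascending eigenvalues
    dA_da = Matrix{Float64}(I, n, n)             # ∂A/∂a: identity
    dA_db = tridiag(0.0, 1.0, n)                 # ∂A/∂b: ones on both off-diagonals
    s = sum(λ[1:nev])
    ds_da = sum(dot(V[:, k], dA_da * V[:, k]) for k in 1:nev)
    ds_db = sum(dot(V[:, k], dA_db * V[:, k]) for k in 1:nev)
    return s, ds_da, ds_db
end

s, ds_da, ds_db = eigsum_and_derivs(2.0, -1.0, 10, 3)
closed_db = sum(2cos(k * π / 11) for k in 1:3)   # closed form for b < 0
@show s ds_da ds_db closed_db
```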

@hz-xiaxz

Another curious thing: as far as I know, ArnoldiMethod.jl doesn't call BLAS directly (please correct me if I'm wrong). So why is there a fallback BLAS warning?
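One possible explanation, not checked against ArnoldiMethod.jl's source: Enzyme differentiates everything the call reaches, including generic LinearAlgebra methods, and some of those forward to BLAS for Float64 arrays even when the calling package never uses BLAS itself. The dasum_64_ symbol in the warning is BLAS asum, the sum of absolute values, i.e. the vector 1-norm. As an implementation detail that may differ between Julia versions, LinearAlgebra can compute norm(x, 1) for strided Float64 vectors via BLAS.asum. The sketch below only shows that these compute the same quantity:

```julia
using LinearAlgebra

x = randn(100)
@show norm(x, 1) ≈ BLAS.asum(x)   # true: asum is the 1-norm
@show norm(x, 1) ≈ sum(abs, x)    # true: the generic definition
```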
