Skip to content

Commit

Permalink
properly handle jac_prototype base on sparsity
Browse files Browse the repository at this point in the history
  • Loading branch information
mjohnson541 committed Mar 10, 2024
1 parent 8e117d3 commit f422946
Showing 1 changed file with 28 additions and 11 deletions.
39 changes: 28 additions & 11 deletions src/Reactor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,20 @@ function Reactor(domain::T, y0::Array{T1,1}, tspan::Tuple, interfaces::Z=[]; p::
prectmp = ilu(W, τ=tau)
preccache = Ref(prectmp)

if sparsity > 0.8
jac_prototype = J

Check warning on line 58 in src/Reactor.jl

View check run for this annotation

Codecov / codecov/patch

src/Reactor.jl#L58

Added line #L58 was not covered by tests
else
jac_prototype = nothing
end

if (forwardsensitivities || !forwarddiff) && domain isa Union{ConstantTPDomain,ConstantVDomain,ConstantPDomain,ParametrizedTPDomain,ParametrizedVDomain,ParametrizedPDomain,ConstantTVDomain,ParametrizedTConstantVDomain,ConstantTAPhiDomain}
if !forwardsensitivities
odefcn = ODEFunction(dydt; jac=jacy!, paramjac=jacp!)
else
odefcn = ODEFunction(dydt; paramjac=jacp!)
end
else
odefcn = ODEFunction(dydt; jac=jacyforwarddiff!, paramjac=jacpforwarddiff!, jac_prototype=float.(J)) #jac_prototype is not needed/used for Sundials solvers but maybe needed for Julia solvers
odefcn = ODEFunction(dydt; jac=jacyforwarddiff!, paramjac=jacpforwarddiff!, jac_prototype=jac_prototype) #jac_prototype is not needed/used for Sundials solvers but maybe needed for Julia solvers

Check warning on line 70 in src/Reactor.jl

View check run for this annotation

Codecov / codecov/patch

src/Reactor.jl#L70

Added line #L70 was not covered by tests
end
if forwardsensitivities
ode = ODEForwardSensitivityProblem(odefcn, y0, tspan, p)
Expand All @@ -78,9 +84,9 @@ function Reactor(domain::T, y0::Array{T1,1}, tspan::Tuple, interfaces::Z=[]; p::
sys = modelingtoolkitize(ode)
jac = eval(ModelingToolkit.generate_jacobian(sys)[2])
if (forwardsensitivities || !forwarddiff) && domain isa Union{ConstantTPDomain,ConstantVDomain,ConstantPDomain,ParametrizedTPDomain,ParametrizedVDomain,ParametrizedPDomain,ConstantTVDomain,ParametrizedTConstantVDomain,ConstantTAPhiDomain}
odefcn = ODEFunction(dydt; jac=jac, paramjac=jacp!)
odefcn = ODEFunction(dydt; jac=jac, paramjac=jacp!, jac_prototype=jac_prototype)

Check warning on line 87 in src/Reactor.jl

View check run for this annotation

Codecov / codecov/patch

src/Reactor.jl#L87

Added line #L87 was not covered by tests
else
odefcn = ODEFunction(dydt; jac=jac, paramjac=jacpforwarddiff!)
odefcn = ODEFunction(dydt; jac=jac, paramjac=jacpforwarddiff!, jac_prototype=jac_prototype)

Check warning on line 89 in src/Reactor.jl

View check run for this annotation

Codecov / codecov/patch

src/Reactor.jl#L89

Added line #L89 was not covered by tests
end
if forwardsensitivities
ode = ODEForwardSensitivityProblem(odefcn, y0, tspan, p)
Expand All @@ -90,7 +96,7 @@ function Reactor(domain::T, y0::Array{T1,1}, tspan::Tuple, interfaces::Z=[]; p::
end
return Reactor(domain, interfaces, y0, tspan, p, ode, recsolver, forwardsensitivities, forwarddiff, modelingtoolkit, tau, precsundials, psetupsundials, precsjulia)
end
function Reactor(domains::T, y0s::W1, tspan::W2, interfaces::Z=Tuple(), ps::X=SciMLBase.NullParameters(); forwardsensitivities=false, modelingtoolkit=false, tau=1e-3) where {T<:Tuple,W1<:Tuple,Z,X,W2}
function Reactor(domains::T, y0s::W1, tspan::W2, interfaces::Z=Tuple(), ps::X=SciMLBase.NullParameters(); forwardsensitivities=false, forwarddiff=false, modelingtoolkit=false, tau=1e-3) where {T<:Tuple,W1<:Tuple,Z,X,W2}
#adjust indexing
y0 = zeros(sum(length(y) for y in y0s))
Nvars = 0
Expand Down Expand Up @@ -217,18 +223,24 @@ function Reactor(domains::T, y0s::W1, tspan::W2, interfaces::Z=Tuple(), ps::X=Sc
prectmp = ilu(W, τ=tau)
preccache = Ref(prectmp)

if sparsity > 0.8
jac_prototype = J

Check warning on line 227 in src/Reactor.jl

View check run for this annotation

Codecov / codecov/patch

src/Reactor.jl#L227

Added line #L227 was not covered by tests
else
jac_prototype = nothing
end

if forwardsensitivities
odefcn = ODEFunction(dydt; paramjac=jacp!)
odefcn = ODEFunction(dydt; paramjac=jacp!, jac_prototype=jac_prototype)

Check warning on line 233 in src/Reactor.jl

View check run for this annotation

Codecov / codecov/patch

src/Reactor.jl#L233

Added line #L233 was not covered by tests
ode = ODEForwardSensitivityProblem(odefcn, y0, tspan, p)
recsolver = Sundials.CVODE_BDF(linear_solver=:GMRES)
if modelingtoolkit
sys = modelingtoolkitize(ode)
jac = eval(ModelingToolkit.generate_jacobian(sys)[2])
odefcn = ODEFunction(dydt; jac=jac, paramjac=jacp!)
odefcn = ODEFunction(dydt; jac=jac, paramjac=jacp!, jac_prototype=jac_prototype)

Check warning on line 239 in src/Reactor.jl

View check run for this annotation

Codecov / codecov/patch

src/Reactor.jl#L239

Added line #L239 was not covered by tests
ode = ODEForwardSensitivityProblem(odefcn, y0, tspan, p)
end
else
odefcn = ODEFunction(dydt; jac=jacy!, paramjac=jacp!, jac_prototype=float.(J))
odefcn = ODEFunction(dydt; jac=jacy!, paramjac=jacp!, jac_prototype=jac_prototype)
ode = ODEProblem(odefcn, y0, tspan, p)
if sparsity > 0.8 #empirical threshold to use preconditioner
recsolver = Sundials.CVODE_BDF(linear_solver=:GMRES, prec=precsundials, psetup=psetupsundials, prec_side=1)
Expand All @@ -238,7 +250,7 @@ function Reactor(domains::T, y0s::W1, tspan::W2, interfaces::Z=Tuple(), ps::X=Sc
if modelingtoolkit
sys = modelingtoolkitize(ode)
jac = eval(ModelingToolkit.generate_jacobian(sys)[2])
odefcn = ODEFunction(dydt; jac=jac, paramjac=jacp!)
odefcn = ODEFunction(dydt; jac=jac, paramjac=jacp!, jac_prototype=jac_prototype)

Check warning on line 253 in src/Reactor.jl

View check run for this annotation

Codecov / codecov/patch

src/Reactor.jl#L253

Added line #L253 was not covered by tests
ode = ODEProblem(odefcn, y0, tspan, p)
end
end
Expand All @@ -256,7 +268,6 @@ function Reactor(domain::T, y0unlumped::Array{W1,1}, tspan::Tuple, reducedmodelm
dydt(dy::X, y::T, p::V, t::Q) where {X,T,Q,V} = dydtreactor!(dy, y, t, domain, interfaces, reducedmodelmappings, reducedmodelcache, p=p)
jacy!(J::Q2, y::T, p::V, t::Q) where {Q2,T,Q,V} = jacobianyforwarddiff!(J, y, p, t, domain, interfaces, reducedmodelmappings, reducedmodelcache)
jacp!(J::Q2, y::T, p::V, t::Q) where {Q2,T,Q,V} = jacobianpforwarddiff!(J, y, p, t, domain, interfaces, reducedmodelmappings, reducedmodelcache)

#y0 in Y space
y0 = zeros(length(reducedmodelmappings.reducedindexes) + length(reducedmodelmappings.lumpedgroupmapping) + length(domain.thermovariabledict))
@inbounds @views y0[1:end-length(domain.thermovariabledict)-length(reducedmodelmappings.lumpedgroupmapping)] .= y0unlumped[reducedmodelmappings.reducedindexes]
Expand Down Expand Up @@ -294,7 +305,13 @@ function Reactor(domain::T, y0unlumped::Array{W1,1}, tspan::Tuple, reducedmodelm
prectmp = ilu(W, τ=tau)
preccache = Ref(prectmp)

odefcn = ODEFunction(dydt; jac=jacy!, paramjac=jacp!, jac_prototype=float.(J)) #jac_prototype is not needed/used for Sundials solvers but maybe needed for Julia solvers
if sparsity > 0.8
jac_prototype = J

Check warning on line 309 in src/Reactor.jl

View check run for this annotation

Codecov / codecov/patch

src/Reactor.jl#L309

Added line #L309 was not covered by tests
else
jac_prototype = nothing
end

odefcn = ODEFunction(dydt; jac=jacy!, paramjac=jacp!, jac_prototype=jac_prototype) #jac_prototype is not needed/used for Sundials solvers but maybe needed for Julia solvers

if forwardsensitivities
ode = ODEForwardSensitivityProblem(odefcn, y0, tspan, p)
Expand All @@ -310,7 +327,7 @@ function Reactor(domain::T, y0unlumped::Array{W1,1}, tspan::Tuple, reducedmodelm
if modelingtoolkit
sys = modelingtoolkitize(ode)
jac = eval(ModelingToolkit.generate_jacobian(sys)[2])
odefcn = ODEFunction(dydt; jac=jac, paramjac=jacp!)
odefcn = ODEFunction(dydt; jac=jac, paramjac=jacp!, jac_prototype=jac_prototype)

Check warning on line 330 in src/Reactor.jl

View check run for this annotation

Codecov / codecov/patch

src/Reactor.jl#L330

Added line #L330 was not covered by tests
if forwardsensitivities
ode = ODEForwardSensitivityProblem(odefcn, y0, tspan, p)
else
Expand Down

0 comments on commit f422946

Please sign in to comment.