-
Notifications
You must be signed in to change notification settings - Fork 64
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
docs: add an Optimization.jl tutorial showcasing lazy data movement
- Loading branch information
Showing
7 changed files
with
205 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
[deps] | ||
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" | ||
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" | ||
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" | ||
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e" | ||
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" | ||
Lux = "b2108857-7c20-44ae-9111-449ecde12c47" | ||
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" | ||
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" | ||
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" | ||
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" | ||
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" | ||
OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a" | ||
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" | ||
|
||
[compat] | ||
CairoMakie = "0.12.10" | ||
ComponentArrays = "0.15.17" | ||
InteractiveUtils = "<0.0.1, 1" | ||
IterTools = "1.10" | ||
Literate = "2.19" | ||
Lux = "1" | ||
LuxCUDA = "0.3.3" | ||
MLUtils = "0.4.4" | ||
Optimization = "3.28.0" | ||
OptimizationOptimJL = "0.3.2" | ||
OptimizationOptimisers = "0.2.1" | ||
OrdinaryDiffEqTsit5 = "1.1.0" | ||
Printf = "1.10" | ||
Random = "1.10" | ||
SciMLSensitivity = "7.67.0" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
# # Training Lux Models using Optimization.jl | ||
|
||
# Lux's native [Training.TrainState](@ref) is a great API for gradient-based learning of | ||
# neural networks; however, it is geared towards using `Optimisers.jl` as the backend. | ||
# Oftentimes, we want to train neural networks with other optimization methods | ||
# like BFGS, LBFGS, etc. In this tutorial, we will show how to train Lux models with | ||
# Optimization.jl, which provides a simple unified interface to various optimization methods. | ||
|
||
# We will base our tutorial on the minibatching tutorial from the official | ||
# [Optimization.jl](https://docs.sciml.ai/Optimization/stable/tutorials/minibatch/) docs. | ||
|
||
# !!! note "Neural ODE" | ||
# | ||
# This tutorial uses a Neural ODE; however, we won't discuss that part in this tutorial. | ||
# Please refer to the Neural ODE tutorial for more information. | ||
|
||
# ## Import packages | ||
|
||
using Lux, Optimization, OptimizationOptimisers, OptimizationOptimJL, OrdinaryDiffEqTsit5, | ||
SciMLSensitivity, Random, MLUtils, IterTools, CairoMakie, ComponentArrays, Printf | ||
using LuxCUDA | ||
|
||
## Device handles used throughout the tutorial: `gdev` moves data to the GPU and `cdev`
## moves it back to the CPU.
const cdev = cpu_device()
const gdev = gpu_device()
|
||
# ## Generate some training data | ||
|
||
## In-place right-hand side of the two-state Lotka-Volterra system used to generate the
## training data. The first two entries of `u` are the states, `p` unpacks to the four
## coefficients `(α, β, δ, γ)`, and the derivatives are written into `du[1]` and `du[2]`.
## The time argument `t` is not used. Always returns `nothing`.
function lotka_volterra(du, u, p, t)
    u1, u2 = u
    α, β, δ, γ = p
    du1 = α * u1 - β * u1 * u2
    du2 = -δ * u2 + γ * u1 * u2
    du[1] = du1
    du[2] = du2
    return nothing
end
|
||
## Initial state, time horizon and number of saved samples of the reference trajectory.
u0 = [1.0f0, 1.0f0]
tspan = (0.0f0, 2.0f0)
datasize = 32

## Time points at which the reference solution is saved.
const t = range(first(tspan), last(tspan); length=datasize)

## Solve the Lotka-Volterra problem with `Tsit5` and keep the saved states as an array.
true_prob = ODEProblem(lotka_volterra, u0, tspan, [1.5, 1.0, 3.0, 1.0])
const ode_data = Array(solve(true_prob, Tsit5(); saveat=t))
|
||
## Plot both rows of the generated data against time as dotted lines.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1])
    for (row, lbl, clr) in ((1, L"u_1(t)", :blue), (2, L"u_2(t)", :red))
        lines!(
            ax, t, ode_data[row, :]; label=lbl, color=clr, linestyle=:dot, linewidth=4)
    end
    axislegend(ax; position=:lt)
    fig
end
|
||
# ## Define the DataLoader | ||
|
||
# We will define the DataLoader to batch over the data. Additionally, we will pipe it through | ||
# the `gdev` device to move the data to the GPU on each iteration. | ||
|
||
# By default `gdev` will move all objects to the GPU. But we don't want to move the time | ||
# vector to the GPU. So we will wrap it in a struct. | ||
## Thin wrapper around the time points. Wrapping them in a custom struct keeps `gdev`
## from moving the time vector to the GPU.
struct TimeWrapper{T}
    t::T
end

## Number of wrapped time points.
function Base.length(tw::TimeWrapper)
    return length(tw.t)
end

## Indexing returns the selected time points wrapped again in a `TimeWrapper`.
## Presumably `length` and `getindex` are what the `DataLoader` uses to slice batches.
function Base.getindex(tw::TimeWrapper, i)
    return TimeWrapper(tw.t[i])
end
|
||
## Batch the trajectory together with its (wrapped) time points, 8 samples per batch, and
## pass the loader through `gdev` so the data is moved to the GPU on each iteration.
dataloader = gdev(DataLoader((ode_data, TimeWrapper(t)); batchsize=8))
|
||
# ## Training the model | ||
|
||
# Here we are using different optimization methods for demonstration purposes. This problem | ||
# is trivial enough to not require this. | ||
|
||
# Optimization.jl requires an abstract array as the parameters, hence we will construct a | ||
# `ComponentArray` to store the parameters. | ||
|
||
# !!! note "Parameter Estimation vs State Estimation" | ||
# | ||
# Optimization.jl performs state estimation, which effectively means for a function | ||
# `f(u, p)`, it is trying to compute the optimal `u` for a given `p`. This terminology | ||
# might be confusing to ML practitioners, since in the ML world, we usually do parameter | ||
# estimation. This effectively means that the `u` in Optimization.jl corresponds to our | ||
# model parameters that are being optimized. | ||
|
||
## Train a small neural ODE on the minibatches produced by `dataloader` and return the
## fitted network wrapped in a `StatefulLuxLayer`.
##
## Each batch is passed to the loss as `(u_batch, t_batch)`, where `t_batch` is a
## `TimeWrapper`. Training runs in three stages: Adam on minibatches, L-BFGS on minibatches
## starting from the Adam result, and finally Adam on the full trajectory `ode_data`.
function train_model(dataloader)
    model = Chain(Dense(2, 32, tanh), Dense(32, 32, tanh), Dense(32, 2))
    ps, st = Lux.setup(Random.default_rng(), model)

    ps_ca = ComponentArray(ps) |> gdev
    st = st |> gdev

    ## Log the loss every 25 iterations. Returning `true` terminates the optimization.
    function callback(state, l)
        state.iter % 25 == 1 && @printf "Iteration: %5d, Loss: %.6e\n" state.iter l
        return l < 1e-8 ## Terminate if loss is small
    end

    smodel = StatefulLuxLayer{true}(model, nothing, st)

    ## Integrate the neural ODE from the first state in `u_batch` over the batch's time
    ## points and return the mean squared error against `u_batch`.
    function loss_adjoint(θ, u_batch, t_batch)
        ts = t_batch.t
        u0 = u_batch[:, 1]
        dudt(u, p, t) = smodel(u, p)
        prob = ODEProblem(dudt, u0, (ts[1], ts[end]), θ)
        pred = convert(AbstractArray, solve(prob, Tsit5(); saveat=ts))
        return MSELoss()(pred, u_batch)
    end

    ## Define the Optimization Function that takes in the optimization state (our parameters)
    ## and optimization parameters (nothing in our case) and data from the dataloader and
    ## returns the loss.
    opt_func = OptimizationFunction(
        (θ, _, u_batch, t_batch) -> loss_adjoint(θ, u_batch, t_batch),
        Optimization.AutoZygote())
    opt_prob = OptimizationProblem(opt_func, ps_ca)

    nepochs = 25
    res_adam = solve(
        opt_prob, Optimisers.Adam(0.001), ncycle(dataloader, nepochs); callback)

    ## Let's finetune a bit with L-BFGS
    opt_prob = remake(opt_prob; u0=res_adam.u)
    res_lbfgs = solve(opt_prob, LBFGS(), ncycle(dataloader, nepochs); callback)

    ## Now that we have a good fit, let's train it on the entire dataset without
    ## Minibatching. We need to do this since ODE solves can lead to accumulated errors if
    ## the model was trained on individual parts (without a data-shooting approach).
    ## The full dataset is moved to the device once here, rather than on every loss
    ## evaluation inside the closure.
    full_data = gdev(ode_data)
    full_time = TimeWrapper(t)
    opt_func = OptimizationFunction(
        (θ, _) -> loss_adjoint(θ, full_data, full_time), Optimization.AutoZygote())
    opt_prob = OptimizationProblem(opt_func, res_lbfgs.u)

    res = solve(opt_prob, Optimisers.Adam(0.005); maxiters=500, callback)

    return StatefulLuxLayer{true}(model, res.u, smodel.st)
end
|
||
trained_model = train_model(dataloader) | ||
nothing #hide | ||
|
||
# ## Plotting the results | ||
|
||
## Right-hand side of the learned dynamics: evaluate the trained network on the state `u`
## with parameters `p` (the time argument is unused).
function dudt(u, p, t)
    return trained_model(u, p)
end

## Solve the learned ODE over the full time span and bring the prediction back to the CPU.
prob = ODEProblem(dudt, gdev(u0), tspan, trained_model.ps)
sol = solve(prob, Tsit5(); saveat=t)
pred = cdev(convert(AbstractArray, sol))
|
||
## Overlay the reference data (dotted) and the model prediction (solid) for both states.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1])
    for (row, lbl, clr) in ((1, L"u_1(t)", :blue), (2, L"u_2(t)", :red))
        lines!(
            ax, t, ode_data[row, :]; label=lbl, color=clr, linestyle=:dot, linewidth=4)
    end
    for (row, lbl, clr) in ((1, L"\hat{u}_1(t)", :blue), (2, L"\hat{u}_2(t)", :red))
        lines!(ax, t, pred[row, :]; label=lbl, color=clr, linewidth=4)
    end
    axislegend(ax; position=:lt)
    fig
end
e7caa61
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lux Benchmarks
Dense(512 => 512, identity)(512 x 128)/forward/CPU/2 thread(s)
412125
ns414500
ns0.99
Dense(512 => 512, identity)(512 x 128)/forward/CPU/4 thread(s)
243958
ns322250
ns0.76
Dense(512 => 512, identity)(512 x 128)/forward/CPU/8 thread(s)
323708
ns322708.5
ns1.00
Dense(512 => 512, identity)(512 x 128)/forward/CPU/1 thread(s)
739541
ns741958
ns1.00
Dense(512 => 512, identity)(512 x 128)/forward/GPU/CUDA
43923
ns44250.5
ns0.99
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/2 thread(s)
1371937.5
ns1327167
ns1.03
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/4 thread(s)
1261354.5
ns2451688
ns0.51
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/8 thread(s)
14008417
ns14209750
ns0.99
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/1 thread(s)
2194708
ns2193937.5
ns1.00
Dense(512 => 512, identity)(512 x 128)/zygote/GPU/CUDA
205355.5
ns207380
ns0.99
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/2 thread(s)
1417791
ns1468292
ns0.97
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/4 thread(s)
887292
ns923959
ns0.96
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/8 thread(s)
1812208
ns1598937.5
ns1.13
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/1 thread(s)
2212208
ns2242395.5
ns0.99
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/2 thread(s)
1733292
ns1762396
ns0.98
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/4 thread(s)
1102709
ns1028250
ns1.07
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/8 thread(s)
1521208.5
ns1537583
ns0.99
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/1 thread(s)
3013042
ns2885833.5
ns1.04
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/GPU/CUDA
206261
ns208790
ns0.99
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/2 thread(s)
12142666.5
ns12117833
ns1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/4 thread(s)
8846250
ns8811750
ns1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/8 thread(s)
9207354
ns9165333.5
ns1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/1 thread(s)
18580625
ns18605125
ns1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/GPU/CUDA
1487526
ns1497201
ns0.99
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/2 thread(s)
17293999.5
ns17314916
ns1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/4 thread(s)
13993542
ns13952000
ns1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/8 thread(s)
14483333.5
ns14449937
ns1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/1 thread(s)
21828604
ns21832333
ns1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/2 thread(s)
250797875.5
ns250356604.5
ns1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/4 thread(s)
148823333
ns148503729
ns1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/8 thread(s)
115883354
ns115663250
ns1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/1 thread(s)
454090334
ns452727834
ns1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/GPU/CUDA
5479922
ns5471701
ns1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/2 thread(s)
1224986000
ns1224679334
ns1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/4 thread(s)
928765000
ns932428750
ns1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/8 thread(s)
827976542
ns831047479.5
ns1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/1 thread(s)
1644101834
ns1654023458
ns0.99
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/GPU/CUDA
31171453
ns31662494
ns0.98
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/2 thread(s)
1135178542
ns1141591625
ns0.99
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/4 thread(s)
999714166.5
ns1004360417
ns1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/8 thread(s)
1303828854
ns1322994750
ns0.99
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/1 thread(s)
1741741896
ns1741933375
ns1.00
lenet(28, 28, 1, 32)/forward/CPU/2 thread(s)
1090792
ns1120833.5
ns0.97
lenet(28, 28, 1, 32)/forward/CPU/4 thread(s)
1637666
ns1620917
ns1.01
lenet(28, 28, 1, 32)/forward/CPU/8 thread(s)
3582708
ns3462083
ns1.03
lenet(28, 28, 1, 32)/forward/CPU/1 thread(s)
786167
ns779667
ns1.01
lenet(28, 28, 1, 32)/forward/GPU/CUDA
262029.5
ns270336.5
ns0.97
lenet(28, 28, 1, 32)/zygote/CPU/2 thread(s)
2971792
ns2988271
ns0.99
lenet(28, 28, 1, 32)/zygote/CPU/4 thread(s)
4123750
ns4139875
ns1.00
lenet(28, 28, 1, 32)/zygote/CPU/8 thread(s)
10426437.5
ns9659916
ns1.08
lenet(28, 28, 1, 32)/zygote/CPU/1 thread(s)
3148229
ns3132834
ns1.00
lenet(28, 28, 1, 32)/zygote/GPU/CUDA
1092866.5
ns1134352.5
ns0.96
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/2 thread(s)
2326979.5
ns2338166
ns1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/4 thread(s)
1322083
ns1437021
ns0.92
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/8 thread(s)
1673459
ns1669291
ns1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/1 thread(s)
4208000
ns4193000
ns1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/GPU/CUDA
208009
ns210459.5
ns0.99
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/2 thread(s)
19416729
ns19441042
ns1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/4 thread(s)
16098667
ns16082770.5
ns1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/8 thread(s)
17355625.5
ns17400416.5
ns1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/1 thread(s)
25850083
ns25866000
ns1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/GPU/CUDA
1585986
ns1593435
ns1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/2 thread(s)
34150479.5
ns34177125
ns1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/4 thread(s)
30929208
ns30976000
ns1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/8 thread(s)
30921458
ns31151000
ns0.99
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/1 thread(s)
36974958
ns36261000
ns1.02
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/2 thread(s)
4531521
ns4537333
ns1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/4 thread(s)
2560770.5
ns2776604
ns0.92
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/8 thread(s)
2921667
ns2913645.5
ns1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/1 thread(s)
8371042
ns8378750
ns1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/GPU/CUDA
418970
ns420670
ns1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/2 thread(s)
38979334
ns38891374.5
ns1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/4 thread(s)
32095416.5
ns32306292
ns0.99
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/8 thread(s)
32177750
ns32384208
ns0.99
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/1 thread(s)
52201375
ns51948083
ns1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/GPU/CUDA
2623471.5
ns2620746.5
ns1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/2 thread(s)
88784146
ns88847729
ns1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/4 thread(s)
115080917
ns114070333.5
ns1.01
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/8 thread(s)
219472208
ns226493250
ns0.97
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/1 thread(s)
74369625
ns73885250
ns1.01
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/2 thread(s)
268757584
ns268317334
ns1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/4 thread(s)
156384958
ns159216084
ns0.98
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/8 thread(s)
126657687.5
ns127078708
ns1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/1 thread(s)
493203959
ns492762417
ns1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/GPU/CUDA
7020761
ns6963353
ns1.01
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/2 thread(s)
1470831958.5
ns1469208062.5
ns1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/4 thread(s)
1166912584
ns1179701333
ns0.99
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/8 thread(s)
1069943646
ns1064469187.5
ns1.01
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/1 thread(s)
2001019750
ns2018298416.5
ns0.99
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/GPU/CUDA
34698558
ns34585385
ns1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/2 thread(s)
1717061375
ns1726168042
ns0.99
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/4 thread(s)
1532673563
ns1532131312.5
ns1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/8 thread(s)
1763256208
ns1753217833
ns1.01
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/1 thread(s)
2235971083
ns2220540250
ns1.01
lenet(28, 28, 1, 128)/forward/CPU/2 thread(s)
2033917
ns2032250
ns1.00
lenet(28, 28, 1, 128)/forward/CPU/4 thread(s)
3001750
ns2850166.5
ns1.05
lenet(28, 28, 1, 128)/forward/CPU/8 thread(s)
6995583
ns7482625
ns0.93
lenet(28, 28, 1, 128)/forward/CPU/1 thread(s)
2410250.5
ns2429979
ns0.99
lenet(28, 28, 1, 128)/forward/GPU/CUDA
277586.5
ns267353.5
ns1.04
lenet(28, 28, 1, 128)/zygote/CPU/2 thread(s)
9596208.5
ns9603854
ns1.00
lenet(28, 28, 1, 128)/zygote/CPU/4 thread(s)
11959208.5
ns11874437.5
ns1.01
lenet(28, 28, 1, 128)/zygote/CPU/8 thread(s)
24601937.5
ns24867021
ns0.99
lenet(28, 28, 1, 128)/zygote/CPU/1 thread(s)
11737000
ns11308542
ns1.04
lenet(28, 28, 1, 128)/zygote/GPU/CUDA
1205760
ns1173785
ns1.03
vgg16(32, 32, 3, 32)/forward/CPU/2 thread(s)
380715583
ns380634584
ns1.00
vgg16(32, 32, 3, 32)/forward/CPU/4 thread(s)
309134958
ns287745375
ns1.07
vgg16(32, 32, 3, 32)/forward/CPU/8 thread(s)
240484604
ns243501229
ns0.99
vgg16(32, 32, 3, 32)/forward/CPU/1 thread(s)
455377687.5
ns452284375.5
ns1.01
vgg16(32, 32, 3, 32)/forward/GPU/CUDA
4901230.5
ns5016811.5
ns0.98
vgg16(32, 32, 3, 32)/zygote/CPU/2 thread(s)
1156536042
ns1137459875
ns1.02
vgg16(32, 32, 3, 32)/zygote/CPU/4 thread(s)
934444375
ns943993333
ns0.99
vgg16(32, 32, 3, 32)/zygote/CPU/8 thread(s)
937955958
ns898262625
ns1.04
vgg16(32, 32, 3, 32)/zygote/CPU/1 thread(s)
1404766250
ns1411909416
ns0.99
vgg16(32, 32, 3, 32)/zygote/GPU/CUDA
19140053
ns18115193
ns1.06
lenet(28, 28, 1, 64)/forward/CPU/2 thread(s)
1044417
ns1060437
ns0.98
lenet(28, 28, 1, 64)/forward/CPU/4 thread(s)
1665375
ns2017041.5
ns0.83
lenet(28, 28, 1, 64)/forward/CPU/8 thread(s)
4659125
ns5113542
ns0.91
lenet(28, 28, 1, 64)/forward/CPU/1 thread(s)
1381458
ns1366833
ns1.01
lenet(28, 28, 1, 64)/forward/GPU/CUDA
277017
ns265207
ns1.04
lenet(28, 28, 1, 64)/zygote/CPU/2 thread(s)
6471291
ns6505083
ns0.99
lenet(28, 28, 1, 64)/zygote/CPU/4 thread(s)
13171521
ns12271187.5
ns1.07
lenet(28, 28, 1, 64)/zygote/CPU/8 thread(s)
19155167
ns18806687.5
ns1.02
lenet(28, 28, 1, 64)/zygote/CPU/1 thread(s)
6084416
ns6078250
ns1.00
lenet(28, 28, 1, 64)/zygote/GPU/CUDA
1234554.5
ns1214045
ns1.02
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/2 thread(s)
70412479
ns70581646
ns1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/4 thread(s)
43727396
ns43485459
ns1.01
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/8 thread(s)
39654166
ns39436292
ns1.01
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/1 thread(s)
132669229.5
ns132675958
ns1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/GPU/CUDA
1946871.5
ns1863920
ns1.04
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/2 thread(s)
356558937
ns355687833.5
ns1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/4 thread(s)
269901750
ns270693083.5
ns1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/8 thread(s)
253901458
ns254405500.5
ns1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/1 thread(s)
543773875
ns538777458
ns1.01
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/GPU/CUDA
12296125.5
ns12367452
ns0.99
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/2 thread(s)
394712500
ns396200000
ns1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/4 thread(s)
374705250
ns402727854
ns0.93
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/8 thread(s)
695586875
ns668679417
ns1.04
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/1 thread(s)
709322375
ns708861625
ns1.00
vgg16(32, 32, 3, 128)/forward/CPU/2 thread(s)
1186149833
ns1187349792
ns1.00
vgg16(32, 32, 3, 128)/forward/CPU/4 thread(s)
826656438
ns694829104
ns1.19
vgg16(32, 32, 3, 128)/forward/CPU/8 thread(s)
633369958
ns629932709
ns1.01
vgg16(32, 32, 3, 128)/forward/CPU/1 thread(s)
1780054896
ns1779143271
ns1.00
vgg16(32, 32, 3, 128)/forward/GPU/CUDA
12309879.5
ns13225818
ns0.93
vgg16(32, 32, 3, 128)/zygote/CPU/2 thread(s)
3667291125
ns3622108083.5
ns1.01
vgg16(32, 32, 3, 128)/zygote/CPU/4 thread(s)
2822189750
ns2828172709
ns1.00
vgg16(32, 32, 3, 128)/zygote/CPU/8 thread(s)
2708758541
ns2724737708
ns0.99
vgg16(32, 32, 3, 128)/zygote/CPU/1 thread(s)
5073574916
ns5083300000
ns1.00
vgg16(32, 32, 3, 128)/zygote/GPU/CUDA
49860253.5
ns49807086.5
ns1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/2 thread(s)
3405875.5
ns3420729.5
ns1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/4 thread(s)
2065250
ns2074875
ns1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/8 thread(s)
2540250
ns2525042
ns1.01
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/1 thread(s)
6014604
ns6011833
ns1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/GPU/CUDA
343548
ns315086
ns1.09
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/2 thread(s)
25973124.5
ns26295500
ns0.99
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/4 thread(s)
18895979.5
ns18987458
ns1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/8 thread(s)
19531416.5
ns19862667
ns0.98
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/1 thread(s)
39260291
ns39218853.5
ns1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/GPU/CUDA
2465534
ns2478386
ns0.99
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/2 thread(s)
55438083
ns55626729.5
ns1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/4 thread(s)
82776500
ns81917708
ns1.01
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/8 thread(s)
170298791.5
ns172510354
ns0.99
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/1 thread(s)
45669667
ns45569417
ns1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/2 thread(s)
1778750
ns1782395.5
ns1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/4 thread(s)
1090500
ns1093791.5
ns1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/8 thread(s)
1558104.5
ns1586291.5
ns0.98
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/1 thread(s)
3024604.5
ns3026979
ns1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/GPU/CUDA
213362
ns213440.5
ns1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/2 thread(s)
12540583.5
ns12557083
ns1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/4 thread(s)
9227000
ns9205917
ns1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/8 thread(s)
9657125
ns9717709
ns0.99
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/1 thread(s)
19007583
ns18945396
ns1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/GPU/CUDA
1541996
ns1545222
ns1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/2 thread(s)
17663416.5
ns17667958
ns1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/4 thread(s)
14318000
ns14312292
ns1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/8 thread(s)
14637687.5
ns14670667
ns1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/1 thread(s)
22165500
ns22150709
ns1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/2 thread(s)
70531042
ns70496583.5
ns1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/4 thread(s)
43664104
ns43541375
ns1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/8 thread(s)
39661792
ns39470417
ns1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/1 thread(s)
132673333
ns132760312.5
ns1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/GPU/CUDA
1958870.5
ns1958343
ns1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/2 thread(s)
360097833
ns358409083
ns1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/4 thread(s)
347045875
ns346583313
ns1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/8 thread(s)
303749416
ns304589375
ns1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/1 thread(s)
730460916
ns725990125
ns1.01
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/GPU/CUDA
13376571.5
ns13320357
ns1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/2 thread(s)
417572000
ns418971104
ns1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/4 thread(s)
427267750
ns419729042
ns1.02
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/8 thread(s)
741225500
ns662505333
ns1.12
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/1 thread(s)
714873709
ns715138292
ns1.00
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/2 thread(s)
1698792
ns1450437
ns1.17
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/4 thread(s)
1054000
ns1298979
ns0.81
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/8 thread(s)
1350854
ns1344645.5
ns1.00
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/1 thread(s)
2441916
ns2365917
ns1.03
mlp7layer_bn(gelu)(32 x 256)/forward/GPU/CUDA
591065.5
ns590150.5
ns1.00
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/2 thread(s)
8953166
ns8684833
ns1.03
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/4 thread(s)
13693270.5
ns12890000
ns1.06
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/8 thread(s)
30593292
ns30836166.5
ns0.99
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/1 thread(s)
9833291
ns9843750
ns1.00
mlp7layer_bn(gelu)(32 x 256)/zygote/GPU/CUDA
1490752.5
ns1473920
ns1.01
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/2 thread(s)
18112458
ns17999292
ns1.01
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/4 thread(s)
17467500.5
ns16546208
ns1.06
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/8 thread(s)
29723417
ns29181291
ns1.02
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/1 thread(s)
14333291
ns14097584
ns1.02
Dense(512 => 512, relu)(512 x 128)/forward/CPU/2 thread(s)
678667
ns693250
ns0.98
Dense(512 => 512, relu)(512 x 128)/forward/CPU/4 thread(s)
520708
ns521417
ns1.00
Dense(512 => 512, relu)(512 x 128)/forward/CPU/8 thread(s)
1030125
ns1040750
ns0.99
Dense(512 => 512, relu)(512 x 128)/forward/CPU/1 thread(s)
724041.5
ns724875
ns1.00
Dense(512 => 512, relu)(512 x 128)/forward/GPU/CUDA
48348.5
ns48072
ns1.01
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/2 thread(s)
1565250
ns1566292
ns1.00
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/4 thread(s)
1021042
ns1002937.5
ns1.02
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/8 thread(s)
1403604
ns1370333.5
ns1.02
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/1 thread(s)
2276771
ns2257250
ns1.01
Dense(512 => 512, relu)(512 x 128)/zygote/GPU/CUDA
240894
ns238196.5
ns1.01
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/2 thread(s)
1558958.5
ns1571020.5
ns0.99
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/4 thread(s)
1066687
ns1080916
ns0.99
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/8 thread(s)
1686104.5
ns1541833
ns1.09
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/1 thread(s)
2226166
ns2236209
ns1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/2 thread(s)
3395333
ns3399875
ns1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/4 thread(s)
2056708
ns2047875
ns1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/8 thread(s)
2516479.5
ns2515021
ns1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/1 thread(s)
6008875
ns6005375
ns1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/GPU/CUDA
288149
ns286172.5
ns1.01
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/2 thread(s)
24072041.5
ns24087042
ns1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/4 thread(s)
17187916
ns17224041.5
ns1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/8 thread(s)
17151041.5
ns17292291
ns0.99
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/1 thread(s)
37508458.5
ns37522062.5
ns1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/GPU/CUDA
2405382
ns2407498
ns1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/2 thread(s)
53630708.5
ns53768270.5
ns1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/4 thread(s)
85883417
ns83654187.5
ns1.03
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/8 thread(s)
169400792
ns169263021
ns1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/1 thread(s)
44653333.5
ns44565333.5
ns1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/2 thread(s)
250291979
ns250492042
ns1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/4 thread(s)
148449041
ns148428250
ns1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/8 thread(s)
115844229
ns115397479.5
ns1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/1 thread(s)
454457542
ns450610604
ns1.01
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/GPU/CUDA
5457486
ns5443833
ns1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/2 thread(s)
1104315584
ns1101924667
ns1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/4 thread(s)
856185292
ns855192187.5
ns1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/8 thread(s)
826815646
ns827218333.5
ns1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/1 thread(s)
1779319041
ns1763706625
ns1.01
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/GPU/CUDA
28848435
ns29367206
ns0.98
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/2 thread(s)
1008544916.5
ns1019223979
ns0.99
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/4 thread(s)
978883958
ns945177042
ns1.04
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/8 thread(s)
1273700875
ns1303173167
ns0.98
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/1 thread(s)
1738185250
ns1739257541.5
ns1.00
mlp7layer_bn(relu)(32 x 256)/forward/CPU/2 thread(s)
1308271
ns1211708
ns1.08
mlp7layer_bn(relu)(32 x 256)/forward/CPU/4 thread(s)
736416
ns981875
ns0.75
mlp7layer_bn(relu)(32 x 256)/forward/CPU/8 thread(s)
970250.5
ns948167
ns1.02
mlp7layer_bn(relu)(32 x 256)/forward/CPU/1 thread(s)
1992812.5
ns2062875
ns0.97
mlp7layer_bn(relu)(32 x 256)/forward/GPU/CUDA
573191
ns569657
ns1.01
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/2 thread(s)
5738375
ns5819083.5
ns0.99
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/4 thread(s)
9161583
ns4699250
ns1.95
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/8 thread(s)
23815937.5
ns24610750.5
ns0.97
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/1 thread(s)
7075209
ns7096333
ns1.00
mlp7layer_bn(relu)(32 x 256)/zygote/GPU/CUDA
1420790.5
ns1369164.5
ns1.04
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/2 thread(s)
10487208
ns11390750
ns0.92
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/4 thread(s)
10551083
ns9112562.5
ns1.16
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/8 thread(s)
17215062.5
ns17263667
ns1.00
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/1 thread(s)
8751250
ns8694666.5
ns1.01
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/2 thread(s)
407270.5
ns384000
ns1.06
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/4 thread(s)
377167
ns364688
ns1.03
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/8 thread(s)
1964791
ns2302437.5
ns0.85
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/1 thread(s)
88084
ns89750
ns0.98
Dense(128 => 128, gelu)(128 x 128)/forward/GPU/CUDA
28331
ns27591.5
ns1.03
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/2 thread(s)
349125.5
ns391125
ns0.89
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/4 thread(s)
443959
ns382584
ns1.16
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/8 thread(s)
4468750
ns4380375
ns1.02
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/1 thread(s)
258666
ns258417
ns1.00
Dense(128 => 128, gelu)(128 x 128)/zygote/GPU/CUDA
226875.5
ns220859
ns1.03
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/2 thread(s)
379791
ns421604
ns0.90
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/4 thread(s)
474250
ns411750
ns1.15
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/8 thread(s)
4250750
ns4491917
ns0.95
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/1 thread(s)
270875
ns271250
ns1.00
Dense(128 => 128, relu)(128 x 128)/forward/CPU/2 thread(s)
352916
ns329896
ns1.07
Dense(128 => 128, relu)(128 x 128)/forward/CPU/4 thread(s)
315333
ns300084
ns1.05
Dense(128 => 128, relu)(128 x 128)/forward/CPU/8 thread(s)
733437.5
ns750333
ns0.98
Dense(128 => 128, relu)(128 x 128)/forward/CPU/1 thread(s)
54271
ns54375
ns1.00
Dense(128 => 128, relu)(128 x 128)/forward/GPU/CUDA
28442.5
ns27841
ns1.02
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/2 thread(s)
296333.5
ns355792
ns0.83
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/4 thread(s)
341000
ns247167
ns1.38
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/8 thread(s)
547833
ns868125
ns0.63
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/1 thread(s)
151666
ns151750
ns1.00
Dense(128 => 128, relu)(128 x 128)/zygote/GPU/CUDA
211615
ns205968
ns1.03
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/2 thread(s)
310375
ns368375
ns0.84
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/4 thread(s)
354833
ns261709
ns1.36
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/8 thread(s)
397666.5
ns714208
ns0.56
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/1 thread(s)
150875
ns151125
ns1.00
vgg16(32, 32, 3, 64)/forward/CPU/2 thread(s)
602928541
ns601673542
ns1.00
vgg16(32, 32, 3, 64)/forward/CPU/4 thread(s)
428260667
ns433401687
ns0.99
vgg16(32, 32, 3, 64)/forward/CPU/8 thread(s)
376195062.5
ns378552750
ns0.99
vgg16(32, 32, 3, 64)/forward/CPU/1 thread(s)
872683792
ns874120625
ns1.00
vgg16(32, 32, 3, 64)/forward/GPU/CUDA
7030918
ns7030592
ns1.00
vgg16(32, 32, 3, 64)/zygote/CPU/2 thread(s)
2004923333
ns2007087354.5
ns1.00
vgg16(32, 32, 3, 64)/zygote/CPU/4 thread(s)
1605834520.5
ns1632009874.5
ns0.98
vgg16(32, 32, 3, 64)/zygote/CPU/8 thread(s)
1586581187
ns1618542583.5
ns0.98
vgg16(32, 32, 3, 64)/zygote/CPU/1 thread(s)
2623747583
ns2637429416
ns0.99
vgg16(32, 32, 3, 64)/zygote/GPU/CUDA
25764798
ns26054721.5
ns0.99
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/2 thread(s)
524375
ns523500
ns1.00
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/4 thread(s)
398542
ns435895.5
ns0.91
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/8 thread(s)
2053208
ns1828249.5
ns1.12
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/1 thread(s)
865312.5
ns866354
ns1.00
Dense(512 => 512, gelu)(512 x 128)/forward/GPU/CUDA
47173
ns47636
ns0.99
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/2 thread(s)
1905417
ns1763270.5
ns1.08
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/4 thread(s)
1742250
ns2797458.5
ns0.62
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/8 thread(s)
15020687.5
ns14370145.5
ns1.05
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/1 thread(s)
2728374.5
ns2769562.5
ns0.99
Dense(512 => 512, gelu)(512 x 128)/zygote/GPU/CUDA
248474
ns248789.5
ns1.00
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/2 thread(s)
1969542
ns1945916.5
ns1.01
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/4 thread(s)
1845500
ns5043500
ns0.37
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/8 thread(s)
14844708.5
ns14572416
ns1.02
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/1 thread(s)
2764666.5
ns2785979.5
ns0.99
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/2 thread(s)
1573416
ns1374375
ns1.14
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/4 thread(s)
947833.5
ns1189542
ns0.80
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/8 thread(s)
1240792
ns1224645.5
ns1.01
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/1 thread(s)
2252813
ns2299000
ns0.98
mlp7layer_bn(tanh)(32 x 256)/forward/GPU/CUDA
585933
ns583268.5
ns1.00
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/2 thread(s)
5942334
ns5918791
ns1.00
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/4 thread(s)
8620916
ns7147000
ns1.21
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/8 thread(s)
25125208
ns24359584
ns1.03
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/1 thread(s)
7308250
ns7320208
ns1.00
mlp7layer_bn(tanh)(32 x 256)/zygote/GPU/CUDA
1387329
ns1348690.5
ns1.03
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/2 thread(s)
13102000
ns13093542
ns1.00
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/4 thread(s)
12175645.5
ns12017167
ns1.01
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/8 thread(s)
20369042
ns20888000
ns0.98
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/1 thread(s)
10679959
ns10214417
ns1.05
Dense(16 => 16, relu)(16 x 128)/forward/CPU/2 thread(s)
2500
ns2375
ns1.05
Dense(16 => 16, relu)(16 x 128)/forward/CPU/4 thread(s)
4917
ns2500
ns1.97
Dense(16 => 16, relu)(16 x 128)/forward/CPU/8 thread(s)
2875
ns3333.5
ns0.86
Dense(16 => 16, relu)(16 x 128)/forward/CPU/1 thread(s)
2458
ns2958
ns0.83
Dense(16 => 16, relu)(16 x 128)/forward/GPU/CUDA
24933
ns24628
ns1.01
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/2 thread(s)
7375
ns7291.5
ns1.01
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/4 thread(s)
7042
ns7083
ns0.99
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/8 thread(s)
7333.5
ns7333.5
ns1
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/1 thread(s)
7125
ns7083
ns1.01
Dense(16 => 16, relu)(16 x 128)/zygote/GPU/CUDA
213416.5
ns209898.5
ns1.02
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/2 thread(s)
8209
ns8250
ns1.00
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/4 thread(s)
8209
ns8208
ns1.00
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/8 thread(s)
8500
ns8375
ns1.01
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/1 thread(s)
5916
ns5958
ns0.99
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/2 thread(s)
10625
ns10458
ns1.02
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/4 thread(s)
16167
ns12937.5
ns1.25
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/8 thread(s)
10875
ns10708
ns1.02
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/1 thread(s)
7187.5
ns7250
ns0.99
Dense(16 => 16, gelu)(16 x 128)/forward/GPU/CUDA
25092
ns24907
ns1.01
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/2 thread(s)
19916
ns19875
ns1.00
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/4 thread(s)
19875
ns20104.5
ns0.99
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/8 thread(s)
20333
ns20125
ns1.01
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/1 thread(s)
19854.5
ns20000
ns0.99
Dense(16 => 16, gelu)(16 x 128)/zygote/GPU/CUDA
233641
ns230594
ns1.01
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/2 thread(s)
23562.5
ns23583.5
ns1.00
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/4 thread(s)
23416
ns23708
ns0.99
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/8 thread(s)
23750
ns23625
ns1.01
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/1 thread(s)
21208
ns21333
ns0.99
Dense(128 => 128, identity)(128 x 128)/forward/CPU/2 thread(s)
28958
ns28459
ns1.02
Dense(128 => 128, identity)(128 x 128)/forward/CPU/4 thread(s)
28416
ns28542
ns1.00
Dense(128 => 128, identity)(128 x 128)/forward/CPU/8 thread(s)
28500
ns28770.5
ns0.99
Dense(128 => 128, identity)(128 x 128)/forward/CPU/1 thread(s)
46041
ns45917
ns1.00
Dense(128 => 128, identity)(128 x 128)/forward/GPU/CUDA
26269
ns25803
ns1.02
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/2 thread(s)
226146
ns230250
ns0.98
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/4 thread(s)
274895.5
ns288166
ns0.95
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/8 thread(s)
4238916.5
ns4212042
ns1.01
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/1 thread(s)
145125
ns145000
ns1.00
Dense(128 => 128, identity)(128 x 128)/zygote/GPU/CUDA
211030
ns207914
ns1.01
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/2 thread(s)
337437.5
ns342187.5
ns0.99
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/4 thread(s)
316687.5
ns333166
ns0.95
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/8 thread(s)
800125
ns411895.5
ns1.94
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/1 thread(s)
161750
ns160646
ns1.01
Dense(16 => 16, identity)(16 x 128)/forward/CPU/2 thread(s)
1917
ns1750
ns1.10
Dense(16 => 16, identity)(16 x 128)/forward/CPU/4 thread(s)
2250
ns1791
ns1.26
Dense(16 => 16, identity)(16 x 128)/forward/CPU/8 thread(s)
2167
ns2250
ns0.96
Dense(16 => 16, identity)(16 x 128)/forward/CPU/1 thread(s)
1875
ns1958
ns0.96
Dense(16 => 16, identity)(16 x 128)/forward/GPU/CUDA
23176
ns23251.5
ns1.00
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/2 thread(s)
5333
ns5208
ns1.02
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/4 thread(s)
5000
ns5208
ns0.96
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/8 thread(s)
5417
ns5500
ns0.98
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/1 thread(s)
5167
ns5291
ns0.98
Dense(16 => 16, identity)(16 x 128)/zygote/GPU/CUDA
256165.5
ns245332
ns1.04
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/2 thread(s)
11250
ns11291.5
ns1.00
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/4 thread(s)
11417
ns11375
ns1.00
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/8 thread(s)
11375
ns11458
ns0.99
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/1 thread(s)
6958
ns6959
ns1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/2 thread(s)
79988042
ns79898667
ns1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/4 thread(s)
47941125
ns49104563
ns0.98
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/8 thread(s)
45005708
ns44920792
ns1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/1 thread(s)
151719833
ns151542042
ns1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/GPU/CUDA
2721409
ns2713787
ns1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/2 thread(s)
663734583
ns665144875
ns1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/4 thread(s)
409152750
ns414328875
ns0.99
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/8 thread(s)
398007625
ns399605708
ns1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/1 thread(s)
690336167
ns687317792
ns1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/GPU/CUDA
14673781
ns14579874
ns1.01
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/2 thread(s)
709715271
ns718439500
ns0.99
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/4 thread(s)
686379083
ns685447833
ns1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/8 thread(s)
980829291
ns1000305625
ns0.98
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/1 thread(s)
997377042
ns992652792
ns1.00
This comment was automatically generated by workflow using github-action-benchmark.