Skip to content

Commit

Permalink
docs: add an Optimization.jl tutorial showcasing lazy data movement
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 14, 2024
1 parent f60db4d commit e7caa61
Show file tree
Hide file tree
Showing 7 changed files with 205 additions and 4 deletions.
3 changes: 2 additions & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ pages = [
"tutorials/beginner/1_Basics.md",
"tutorials/beginner/2_PolynomialFitting.md",
"tutorials/beginner/3_SimpleRNN.md",
"tutorials/beginner/4_SimpleChains.md"
"tutorials/beginner/4_SimpleChains.md",
"tutorials/beginner/5_OptimizationIntegration.md"
],
"Intermediate" => [
"tutorials/intermediate/1_NeuralODE.md",
Expand Down
4 changes: 3 additions & 1 deletion docs/src/.vitepress/config.mts
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,9 @@ export default defineConfig({
{ text: 'Julia & Lux for the Uninitiated', link: '/tutorials/beginner/1_Basics' },
{ text: 'Fitting a Polynomial using MLP', link: '/tutorials/beginner/2_PolynomialFitting' },
{ text: 'Training a Simple LSTM', link: '/tutorials/beginner/3_SimpleRNN' },
{ text: 'MNIST Classification with SimpleChains', link: '/tutorials/beginner/4_SimpleChains' }]
{ text: 'MNIST Classification with SimpleChains', link: '/tutorials/beginner/4_SimpleChains' },
{ text: 'Fitting with Optimization.jl', link: '/tutorials/beginner/5_OptimizationIntegration' },
]
},
{
text: 'Intermediate', collapsed: false, items: [
Expand Down
Binary file added docs/src/public/optimization_integration.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 7 additions & 1 deletion docs/src/tutorials/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,17 @@ const beginner = [
caption: "Use SimpleChains.jl as a Backend",
desc: "Learn how to train small neural networks really fast on CPU."
},
{
href: "beginner/5_OptimizationIntegration",
src: "../optimization_integration.png",
caption: "Fitting with Optimization.jl",
desc: "Learn how to use Optimization.jl with Lux (on GPUs)."
},
{
href: "https://luxdl.github.io/Boltz.jl/stable/tutorials/1_GettingStarted",
src: "https://production-media.paperswithcode.com/datasets/ImageNet-0000000008-f2e87edd_Y0fT5zg.jpg",
caption: "Pre-Built Deep Learning Models",
desc: "Use Boltz.jl to load pre-built deep learning and scientific machine learning models."
desc: "Use Boltz.jl to load pre-built DL and SciML models."
}
];
Expand Down
3 changes: 2 additions & 1 deletion docs/tutorials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ const BEGINNER_TUTORIALS = [
"Basics/main.jl" => "CUDA",
"PolynomialFitting/main.jl" => "CUDA",
"SimpleRNN/main.jl" => "CUDA",
"SimpleChains/main.jl" => "CUDA"
"SimpleChains/main.jl" => "CUDA",
"OptimizationIntegration/main.jl" => "CUDA",
]
const INTERMEDIATE_TUTORIALS = [
"NeuralODE/main.jl" => "CUDA",
Expand Down
33 changes: 33 additions & 0 deletions examples/OptimizationIntegration/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
[deps]
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"

[compat]
CairoMakie = "0.12.10"
ComponentArrays = "0.15.17"
InteractiveUtils = "<0.0.1, 1"
IterTools = "1.10"
Literate = "2.19"
Lux = "1"
LuxCUDA = "0.3.3"
MLUtils = "0.4.4"
Optimization = "3.28.0"
OptimizationOptimJL = "0.3.2"
OptimizationOptimisers = "0.2.1"
OrdinaryDiffEqTsit5 = "1.1.0"
Printf = "1.10"
Random = "1.10"
SciMLSensitivity = "7.67.0"
158 changes: 158 additions & 0 deletions examples/OptimizationIntegration/main.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# # Training Lux Models using Optimization.jl

# Lux's native [Training.TrainState](@ref) is a great API for gradient-based learning of
# neural networks, however, it is geared towards using `Optimisers.jl` as the backend.
# However, often times we want to train the neural networks with other optimization methods
# like BFGS, LBFGS, etc. In this tutorial, we will show how to train Lux models with
# Optimization.jl that provides a simple unified interface to various optimization methods.

# We will base our tutorial on the minibatching tutorial from the official
# [Optimization.jl](https://docs.sciml.ai/Optimization/stable/tutorials/minibatch/) docs.

# !!! note "Neural ODE"
#
# This tutorial uses a Neural ODE, however, we won't discuss that part in this tutorial.
# Please refer to the Neural ODE tutorial for more information.

# ## Imports packages

using Lux, Optimization, OptimizationOptimisers, OptimizationOptimJL, OrdinaryDiffEqTsit5,
SciMLSensitivity, Random, MLUtils, IterTools, CairoMakie, ComponentArrays, Printf
using LuxCUDA

const gdev = gpu_device()
const cdev = cpu_device()

# ## Generate some training data

function lotka_volterra(du, u, p, t)
x, y = u
α, β, δ, γ = p
du[1] = α * x - β * x * y
du[2] = -δ * y + γ * x * y
return nothing
end

u0 = [1.0f0, 1.0f0]

datasize = 32
tspan = (0.0f0, 2.0f0)

const t = range(tspan[1], tspan[2]; length=datasize)
true_prob = ODEProblem(lotka_volterra, u0, (tspan[1], tspan[2]), [1.5, 1.0, 3.0, 1.0])
const ode_data = Array(solve(true_prob, Tsit5(); saveat=t))

begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1])
lines!(ax, t, ode_data[1, :]; label=L"u_1(t)", color=:blue, linestyle=:dot, linewidth=4)
lines!(ax, t, ode_data[2, :]; label=L"u_2(t)", color=:red, linestyle=:dot, linewidth=4)
axislegend(ax; position=:lt)
fig
end

# ## Define the DataLoader

# We will define the DataLoader to batch over the data, additionally we will pipe it through
# the `gdev` device to move the data to the GPU on each iteration.

# By default `gdev` will move all objects to the GPU. But we don't want to move the time
# vector to the GPU. So we will wrap it in a struct.
struct TimeWrapper{T}
t::T
end

Base.length(t::TimeWrapper) = length(t.t)

Base.getindex(t::TimeWrapper, i) = TimeWrapper(t.t[i])

dataloader = DataLoader((ode_data, TimeWrapper(t)); batchsize=8) |> gdev

# ## Training the model

# Here we are using different optimization methods for demonstration purposes. This problem
# is trivial enough to not require this.

# Optimization.jl requires an abstract array as the parameters, hence we will construct a
# `ComponentArray` to store the parameters.

# !!! note "Parameter Estimation vs State Estimation"
#
# Optimization.jl performs state estimation, which effectively means for a function
# `f(u, p)`, it is trying to compute the optimal `u` for a given `p`. This terminology
# might be confusing to ML practitioners, since in the ML world, we usually do parameter
# estimation. This effectively means that the `u` in Optimization.jl corresponds to our
# model parameters that is being optimized.

function train_model(dataloader)
model = Chain(Dense(2, 32, tanh), Dense(32, 32, tanh), Dense(32, 2))
ps, st = Lux.setup(Random.default_rng(), model)

ps_ca = ComponentArray(ps) |> gdev
st = st |> gdev

function callback(state, l)
state.iter % 25 == 1 && @printf "Iteration: %5d, Loss: %.6e\n" state.iter l
return l < 1e-8 ## Terminate if loss is small
end

smodel = StatefulLuxLayer{true}(model, nothing, st)

function loss_adjoint(θ, u_batch, t_batch)
t_batch = t_batch.t
u0 = u_batch[:, 1]
dudt(u, p, t) = smodel(u, p)
prob = ODEProblem(dudt, u0, (t_batch[1], t_batch[end]), θ)
pred = convert(AbstractArray, solve(prob, Tsit5(); saveat=t_batch))
return MSELoss()(pred, u_batch)
end

## Define the Optimization Function that takes in the optimization state (our parameters)
## and optimization parameters (nothing in our case) and data from the dataloader and
## returns the loss.
opt_func = OptimizationFunction(
(θ, _, u_batch, t_batch) -> loss_adjoint(θ, u_batch, t_batch),
Optimization.AutoZygote())
opt_prob = OptimizationProblem(opt_func, ps_ca)

nepcohs = 25
res_adam = solve(
opt_prob, Optimisers.Adam(0.001), ncycle(dataloader, nepcohs); callback)

## Let's finetune a bit with L-BFGS
opt_prob = remake(opt_prob; u0=res_adam.u)
res_lbfgs = solve(opt_prob, LBFGS(), ncycle(dataloader, nepcohs); callback)

## Now that we have a good fit, let's train it on the entire dataset without
## Minibatching. We need to do this since ODE solves can lead to accumulated errors if
## the model was trained on individual parts (without a data-shooting approach).
opt_func = OptimizationFunction(
(θ, _) -> loss_adjoint(θ, gdev(ode_data), TimeWrapper(t)),
Optimization.AutoZygote())
opt_prob = OptimizationProblem(opt_func, res_lbfgs.u)

res = solve(opt_prob, Optimisers.Adam(0.005); maxiters=500, callback)

return StatefulLuxLayer{true}(model, res.u, smodel.st)
end

trained_model = train_model(dataloader)
nothing #hide

# ## Plotting the results

dudt(u, p, t) = trained_model(u, p)
prob = ODEProblem(dudt, gdev(u0), (tspan[1], tspan[2]), trained_model.ps)
sol = solve(prob, Tsit5(); saveat=t)
pred = convert(AbstractArray, sol) |> cdev

begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1])
lines!(ax, t, ode_data[1, :]; label=L"u_1(t)", color=:blue, linestyle=:dot, linewidth=4)
lines!(ax, t, ode_data[2, :]; label=L"u_2(t)", color=:red, linestyle=:dot, linewidth=4)
lines!(ax, t, pred[1, :]; label=L"\hat{u}_1(t)", color=:blue, linewidth=4)
lines!(ax, t, pred[2, :]; label=L"\hat{u}_2(t)", color=:red, linewidth=4)
axislegend(ax; position=:lt)
fig
end

1 comment on commit e7caa61

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lux Benchmarks

Benchmark suite Current: e7caa61 Previous: f60db4d Ratio
Dense(512 => 512, identity)(512 x 128)/forward/CPU/2 thread(s) 412125 ns 414500 ns 0.99
Dense(512 => 512, identity)(512 x 128)/forward/CPU/4 thread(s) 243958 ns 322250 ns 0.76
Dense(512 => 512, identity)(512 x 128)/forward/CPU/8 thread(s) 323708 ns 322708.5 ns 1.00
Dense(512 => 512, identity)(512 x 128)/forward/CPU/1 thread(s) 739541 ns 741958 ns 1.00
Dense(512 => 512, identity)(512 x 128)/forward/GPU/CUDA 43923 ns 44250.5 ns 0.99
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/2 thread(s) 1371937.5 ns 1327167 ns 1.03
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/4 thread(s) 1261354.5 ns 2451688 ns 0.51
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/8 thread(s) 14008417 ns 14209750 ns 0.99
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/1 thread(s) 2194708 ns 2193937.5 ns 1.00
Dense(512 => 512, identity)(512 x 128)/zygote/GPU/CUDA 205355.5 ns 207380 ns 0.99
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/2 thread(s) 1417791 ns 1468292 ns 0.97
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/4 thread(s) 887292 ns 923959 ns 0.96
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/8 thread(s) 1812208 ns 1598937.5 ns 1.13
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/1 thread(s) 2212208 ns 2242395.5 ns 0.99
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/2 thread(s) 1733292 ns 1762396 ns 0.98
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/4 thread(s) 1102709 ns 1028250 ns 1.07
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/8 thread(s) 1521208.5 ns 1537583 ns 0.99
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/1 thread(s) 3013042 ns 2885833.5 ns 1.04
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/GPU/CUDA 206261 ns 208790 ns 0.99
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/2 thread(s) 12142666.5 ns 12117833 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/4 thread(s) 8846250 ns 8811750 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/8 thread(s) 9207354 ns 9165333.5 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/1 thread(s) 18580625 ns 18605125 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/GPU/CUDA 1487526 ns 1497201 ns 0.99
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/2 thread(s) 17293999.5 ns 17314916 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/4 thread(s) 13993542 ns 13952000 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/8 thread(s) 14483333.5 ns 14449937 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/1 thread(s) 21828604 ns 21832333 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/2 thread(s) 250797875.5 ns 250356604.5 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/4 thread(s) 148823333 ns 148503729 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/8 thread(s) 115883354 ns 115663250 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/1 thread(s) 454090334 ns 452727834 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/GPU/CUDA 5479922 ns 5471701 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/2 thread(s) 1224986000 ns 1224679334 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/4 thread(s) 928765000 ns 932428750 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/8 thread(s) 827976542 ns 831047479.5 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/1 thread(s) 1644101834 ns 1654023458 ns 0.99
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/GPU/CUDA 31171453 ns 31662494 ns 0.98
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/2 thread(s) 1135178542 ns 1141591625 ns 0.99
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/4 thread(s) 999714166.5 ns 1004360417 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/8 thread(s) 1303828854 ns 1322994750 ns 0.99
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/1 thread(s) 1741741896 ns 1741933375 ns 1.00
lenet(28, 28, 1, 32)/forward/CPU/2 thread(s) 1090792 ns 1120833.5 ns 0.97
lenet(28, 28, 1, 32)/forward/CPU/4 thread(s) 1637666 ns 1620917 ns 1.01
lenet(28, 28, 1, 32)/forward/CPU/8 thread(s) 3582708 ns 3462083 ns 1.03
lenet(28, 28, 1, 32)/forward/CPU/1 thread(s) 786167 ns 779667 ns 1.01
lenet(28, 28, 1, 32)/forward/GPU/CUDA 262029.5 ns 270336.5 ns 0.97
lenet(28, 28, 1, 32)/zygote/CPU/2 thread(s) 2971792 ns 2988271 ns 0.99
lenet(28, 28, 1, 32)/zygote/CPU/4 thread(s) 4123750 ns 4139875 ns 1.00
lenet(28, 28, 1, 32)/zygote/CPU/8 thread(s) 10426437.5 ns 9659916 ns 1.08
lenet(28, 28, 1, 32)/zygote/CPU/1 thread(s) 3148229 ns 3132834 ns 1.00
lenet(28, 28, 1, 32)/zygote/GPU/CUDA 1092866.5 ns 1134352.5 ns 0.96
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/2 thread(s) 2326979.5 ns 2338166 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/4 thread(s) 1322083 ns 1437021 ns 0.92
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/8 thread(s) 1673459 ns 1669291 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/1 thread(s) 4208000 ns 4193000 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/GPU/CUDA 208009 ns 210459.5 ns 0.99
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/2 thread(s) 19416729 ns 19441042 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/4 thread(s) 16098667 ns 16082770.5 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/8 thread(s) 17355625.5 ns 17400416.5 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/1 thread(s) 25850083 ns 25866000 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/GPU/CUDA 1585986 ns 1593435 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/2 thread(s) 34150479.5 ns 34177125 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/4 thread(s) 30929208 ns 30976000 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/8 thread(s) 30921458 ns 31151000 ns 0.99
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/1 thread(s) 36974958 ns 36261000 ns 1.02
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/2 thread(s) 4531521 ns 4537333 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/4 thread(s) 2560770.5 ns 2776604 ns 0.92
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/8 thread(s) 2921667 ns 2913645.5 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/1 thread(s) 8371042 ns 8378750 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/GPU/CUDA 418970 ns 420670 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/2 thread(s) 38979334 ns 38891374.5 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/4 thread(s) 32095416.5 ns 32306292 ns 0.99
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/8 thread(s) 32177750 ns 32384208 ns 0.99
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/1 thread(s) 52201375 ns 51948083 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/GPU/CUDA 2623471.5 ns 2620746.5 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/2 thread(s) 88784146 ns 88847729 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/4 thread(s) 115080917 ns 114070333.5 ns 1.01
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/8 thread(s) 219472208 ns 226493250 ns 0.97
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/1 thread(s) 74369625 ns 73885250 ns 1.01
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/2 thread(s) 268757584 ns 268317334 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/4 thread(s) 156384958 ns 159216084 ns 0.98
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/8 thread(s) 126657687.5 ns 127078708 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/1 thread(s) 493203959 ns 492762417 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/GPU/CUDA 7020761 ns 6963353 ns 1.01
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/2 thread(s) 1470831958.5 ns 1469208062.5 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/4 thread(s) 1166912584 ns 1179701333 ns 0.99
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/8 thread(s) 1069943646 ns 1064469187.5 ns 1.01
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/1 thread(s) 2001019750 ns 2018298416.5 ns 0.99
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/GPU/CUDA 34698558 ns 34585385 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/2 thread(s) 1717061375 ns 1726168042 ns 0.99
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/4 thread(s) 1532673563 ns 1532131312.5 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/8 thread(s) 1763256208 ns 1753217833 ns 1.01
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/1 thread(s) 2235971083 ns 2220540250 ns 1.01
lenet(28, 28, 1, 128)/forward/CPU/2 thread(s) 2033917 ns 2032250 ns 1.00
lenet(28, 28, 1, 128)/forward/CPU/4 thread(s) 3001750 ns 2850166.5 ns 1.05
lenet(28, 28, 1, 128)/forward/CPU/8 thread(s) 6995583 ns 7482625 ns 0.93
lenet(28, 28, 1, 128)/forward/CPU/1 thread(s) 2410250.5 ns 2429979 ns 0.99
lenet(28, 28, 1, 128)/forward/GPU/CUDA 277586.5 ns 267353.5 ns 1.04
lenet(28, 28, 1, 128)/zygote/CPU/2 thread(s) 9596208.5 ns 9603854 ns 1.00
lenet(28, 28, 1, 128)/zygote/CPU/4 thread(s) 11959208.5 ns 11874437.5 ns 1.01
lenet(28, 28, 1, 128)/zygote/CPU/8 thread(s) 24601937.5 ns 24867021 ns 0.99
lenet(28, 28, 1, 128)/zygote/CPU/1 thread(s) 11737000 ns 11308542 ns 1.04
lenet(28, 28, 1, 128)/zygote/GPU/CUDA 1205760 ns 1173785 ns 1.03
vgg16(32, 32, 3, 32)/forward/CPU/2 thread(s) 380715583 ns 380634584 ns 1.00
vgg16(32, 32, 3, 32)/forward/CPU/4 thread(s) 309134958 ns 287745375 ns 1.07
vgg16(32, 32, 3, 32)/forward/CPU/8 thread(s) 240484604 ns 243501229 ns 0.99
vgg16(32, 32, 3, 32)/forward/CPU/1 thread(s) 455377687.5 ns 452284375.5 ns 1.01
vgg16(32, 32, 3, 32)/forward/GPU/CUDA 4901230.5 ns 5016811.5 ns 0.98
vgg16(32, 32, 3, 32)/zygote/CPU/2 thread(s) 1156536042 ns 1137459875 ns 1.02
vgg16(32, 32, 3, 32)/zygote/CPU/4 thread(s) 934444375 ns 943993333 ns 0.99
vgg16(32, 32, 3, 32)/zygote/CPU/8 thread(s) 937955958 ns 898262625 ns 1.04
vgg16(32, 32, 3, 32)/zygote/CPU/1 thread(s) 1404766250 ns 1411909416 ns 0.99
vgg16(32, 32, 3, 32)/zygote/GPU/CUDA 19140053 ns 18115193 ns 1.06
lenet(28, 28, 1, 64)/forward/CPU/2 thread(s) 1044417 ns 1060437 ns 0.98
lenet(28, 28, 1, 64)/forward/CPU/4 thread(s) 1665375 ns 2017041.5 ns 0.83
lenet(28, 28, 1, 64)/forward/CPU/8 thread(s) 4659125 ns 5113542 ns 0.91
lenet(28, 28, 1, 64)/forward/CPU/1 thread(s) 1381458 ns 1366833 ns 1.01
lenet(28, 28, 1, 64)/forward/GPU/CUDA 277017 ns 265207 ns 1.04
lenet(28, 28, 1, 64)/zygote/CPU/2 thread(s) 6471291 ns 6505083 ns 0.99
lenet(28, 28, 1, 64)/zygote/CPU/4 thread(s) 13171521 ns 12271187.5 ns 1.07
lenet(28, 28, 1, 64)/zygote/CPU/8 thread(s) 19155167 ns 18806687.5 ns 1.02
lenet(28, 28, 1, 64)/zygote/CPU/1 thread(s) 6084416 ns 6078250 ns 1.00
lenet(28, 28, 1, 64)/zygote/GPU/CUDA 1234554.5 ns 1214045 ns 1.02
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/2 thread(s) 70412479 ns 70581646 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/4 thread(s) 43727396 ns 43485459 ns 1.01
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/8 thread(s) 39654166 ns 39436292 ns 1.01
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/1 thread(s) 132669229.5 ns 132675958 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/GPU/CUDA 1946871.5 ns 1863920 ns 1.04
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/2 thread(s) 356558937 ns 355687833.5 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/4 thread(s) 269901750 ns 270693083.5 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/8 thread(s) 253901458 ns 254405500.5 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/1 thread(s) 543773875 ns 538777458 ns 1.01
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/GPU/CUDA 12296125.5 ns 12367452 ns 0.99
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/2 thread(s) 394712500 ns 396200000 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/4 thread(s) 374705250 ns 402727854 ns 0.93
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/8 thread(s) 695586875 ns 668679417 ns 1.04
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/1 thread(s) 709322375 ns 708861625 ns 1.00
vgg16(32, 32, 3, 128)/forward/CPU/2 thread(s) 1186149833 ns 1187349792 ns 1.00
vgg16(32, 32, 3, 128)/forward/CPU/4 thread(s) 826656438 ns 694829104 ns 1.19
vgg16(32, 32, 3, 128)/forward/CPU/8 thread(s) 633369958 ns 629932709 ns 1.01
vgg16(32, 32, 3, 128)/forward/CPU/1 thread(s) 1780054896 ns 1779143271 ns 1.00
vgg16(32, 32, 3, 128)/forward/GPU/CUDA 12309879.5 ns 13225818 ns 0.93
vgg16(32, 32, 3, 128)/zygote/CPU/2 thread(s) 3667291125 ns 3622108083.5 ns 1.01
vgg16(32, 32, 3, 128)/zygote/CPU/4 thread(s) 2822189750 ns 2828172709 ns 1.00
vgg16(32, 32, 3, 128)/zygote/CPU/8 thread(s) 2708758541 ns 2724737708 ns 0.99
vgg16(32, 32, 3, 128)/zygote/CPU/1 thread(s) 5073574916 ns 5083300000 ns 1.00
vgg16(32, 32, 3, 128)/zygote/GPU/CUDA 49860253.5 ns 49807086.5 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/2 thread(s) 3405875.5 ns 3420729.5 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/4 thread(s) 2065250 ns 2074875 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/8 thread(s) 2540250 ns 2525042 ns 1.01
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/1 thread(s) 6014604 ns 6011833 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/GPU/CUDA 343548 ns 315086 ns 1.09
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/2 thread(s) 25973124.5 ns 26295500 ns 0.99
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/4 thread(s) 18895979.5 ns 18987458 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/8 thread(s) 19531416.5 ns 19862667 ns 0.98
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/1 thread(s) 39260291 ns 39218853.5 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/GPU/CUDA 2465534 ns 2478386 ns 0.99
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/2 thread(s) 55438083 ns 55626729.5 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/4 thread(s) 82776500 ns 81917708 ns 1.01
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/8 thread(s) 170298791.5 ns 172510354 ns 0.99
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/1 thread(s) 45669667 ns 45569417 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/2 thread(s) 1778750 ns 1782395.5 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/4 thread(s) 1090500 ns 1093791.5 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/8 thread(s) 1558104.5 ns 1586291.5 ns 0.98
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/1 thread(s) 3024604.5 ns 3026979 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/GPU/CUDA 213362 ns 213440.5 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/2 thread(s) 12540583.5 ns 12557083 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/4 thread(s) 9227000 ns 9205917 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/8 thread(s) 9657125 ns 9717709 ns 0.99
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/1 thread(s) 19007583 ns 18945396 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/GPU/CUDA 1541996 ns 1545222 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/2 thread(s) 17663416.5 ns 17667958 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/4 thread(s) 14318000 ns 14312292 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/8 thread(s) 14637687.5 ns 14670667 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/1 thread(s) 22165500 ns 22150709 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/2 thread(s) 70531042 ns 70496583.5 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/4 thread(s) 43664104 ns 43541375 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/8 thread(s) 39661792 ns 39470417 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/1 thread(s) 132673333 ns 132760312.5 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/GPU/CUDA 1958870.5 ns 1958343 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/2 thread(s) 360097833 ns 358409083 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/4 thread(s) 347045875 ns 346583313 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/8 thread(s) 303749416 ns 304589375 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/1 thread(s) 730460916 ns 725990125 ns 1.01
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/GPU/CUDA 13376571.5 ns 13320357 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/2 thread(s) 417572000 ns 418971104 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/4 thread(s) 427267750 ns 419729042 ns 1.02
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/8 thread(s) 741225500 ns 662505333 ns 1.12
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/1 thread(s) 714873709 ns 715138292 ns 1.00
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/2 thread(s) 1698792 ns 1450437 ns 1.17
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/4 thread(s) 1054000 ns 1298979 ns 0.81
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/8 thread(s) 1350854 ns 1344645.5 ns 1.00
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/1 thread(s) 2441916 ns 2365917 ns 1.03
mlp7layer_bn(gelu)(32 x 256)/forward/GPU/CUDA 591065.5 ns 590150.5 ns 1.00
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/2 thread(s) 8953166 ns 8684833 ns 1.03
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/4 thread(s) 13693270.5 ns 12890000 ns 1.06
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/8 thread(s) 30593292 ns 30836166.5 ns 0.99
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/1 thread(s) 9833291 ns 9843750 ns 1.00
mlp7layer_bn(gelu)(32 x 256)/zygote/GPU/CUDA 1490752.5 ns 1473920 ns 1.01
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/2 thread(s) 18112458 ns 17999292 ns 1.01
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/4 thread(s) 17467500.5 ns 16546208 ns 1.06
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/8 thread(s) 29723417 ns 29181291 ns 1.02
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/1 thread(s) 14333291 ns 14097584 ns 1.02
Dense(512 => 512, relu)(512 x 128)/forward/CPU/2 thread(s) 678667 ns 693250 ns 0.98
Dense(512 => 512, relu)(512 x 128)/forward/CPU/4 thread(s) 520708 ns 521417 ns 1.00
Dense(512 => 512, relu)(512 x 128)/forward/CPU/8 thread(s) 1030125 ns 1040750 ns 0.99
Dense(512 => 512, relu)(512 x 128)/forward/CPU/1 thread(s) 724041.5 ns 724875 ns 1.00
Dense(512 => 512, relu)(512 x 128)/forward/GPU/CUDA 48348.5 ns 48072 ns 1.01
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/2 thread(s) 1565250 ns 1566292 ns 1.00
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/4 thread(s) 1021042 ns 1002937.5 ns 1.02
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/8 thread(s) 1403604 ns 1370333.5 ns 1.02
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/1 thread(s) 2276771 ns 2257250 ns 1.01
Dense(512 => 512, relu)(512 x 128)/zygote/GPU/CUDA 240894 ns 238196.5 ns 1.01
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/2 thread(s) 1558958.5 ns 1571020.5 ns 0.99
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/4 thread(s) 1066687 ns 1080916 ns 0.99
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/8 thread(s) 1686104.5 ns 1541833 ns 1.09
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/1 thread(s) 2226166 ns 2236209 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/2 thread(s) 3395333 ns 3399875 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/4 thread(s) 2056708 ns 2047875 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/8 thread(s) 2516479.5 ns 2515021 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/1 thread(s) 6008875 ns 6005375 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/GPU/CUDA 288149 ns 286172.5 ns 1.01
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/2 thread(s) 24072041.5 ns 24087042 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/4 thread(s) 17187916 ns 17224041.5 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/8 thread(s) 17151041.5 ns 17292291 ns 0.99
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/1 thread(s) 37508458.5 ns 37522062.5 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/GPU/CUDA 2405382 ns 2407498 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/2 thread(s) 53630708.5 ns 53768270.5 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/4 thread(s) 85883417 ns 83654187.5 ns 1.03
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/8 thread(s) 169400792 ns 169263021 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/1 thread(s) 44653333.5 ns 44565333.5 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/2 thread(s) 250291979 ns 250492042 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/4 thread(s) 148449041 ns 148428250 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/8 thread(s) 115844229 ns 115397479.5 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/1 thread(s) 454457542 ns 450610604 ns 1.01
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/GPU/CUDA 5457486 ns 5443833 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/2 thread(s) 1104315584 ns 1101924667 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/4 thread(s) 856185292 ns 855192187.5 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/8 thread(s) 826815646 ns 827218333.5 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/1 thread(s) 1779319041 ns 1763706625 ns 1.01
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/GPU/CUDA 28848435 ns 29367206 ns 0.98
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/2 thread(s) 1008544916.5 ns 1019223979 ns 0.99
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/4 thread(s) 978883958 ns 945177042 ns 1.04
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/8 thread(s) 1273700875 ns 1303173167 ns 0.98
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/1 thread(s) 1738185250 ns 1739257541.5 ns 1.00
mlp7layer_bn(relu)(32 x 256)/forward/CPU/2 thread(s) 1308271 ns 1211708 ns 1.08
mlp7layer_bn(relu)(32 x 256)/forward/CPU/4 thread(s) 736416 ns 981875 ns 0.75
mlp7layer_bn(relu)(32 x 256)/forward/CPU/8 thread(s) 970250.5 ns 948167 ns 1.02
mlp7layer_bn(relu)(32 x 256)/forward/CPU/1 thread(s) 1992812.5 ns 2062875 ns 0.97
mlp7layer_bn(relu)(32 x 256)/forward/GPU/CUDA 573191 ns 569657 ns 1.01
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/2 thread(s) 5738375 ns 5819083.5 ns 0.99
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/4 thread(s) 9161583 ns 4699250 ns 1.95
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/8 thread(s) 23815937.5 ns 24610750.5 ns 0.97
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/1 thread(s) 7075209 ns 7096333 ns 1.00
mlp7layer_bn(relu)(32 x 256)/zygote/GPU/CUDA 1420790.5 ns 1369164.5 ns 1.04
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/2 thread(s) 10487208 ns 11390750 ns 0.92
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/4 thread(s) 10551083 ns 9112562.5 ns 1.16
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/8 thread(s) 17215062.5 ns 17263667 ns 1.00
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/1 thread(s) 8751250 ns 8694666.5 ns 1.01
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/2 thread(s) 407270.5 ns 384000 ns 1.06
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/4 thread(s) 377167 ns 364688 ns 1.03
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/8 thread(s) 1964791 ns 2302437.5 ns 0.85
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/1 thread(s) 88084 ns 89750 ns 0.98
Dense(128 => 128, gelu)(128 x 128)/forward/GPU/CUDA 28331 ns 27591.5 ns 1.03
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/2 thread(s) 349125.5 ns 391125 ns 0.89
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/4 thread(s) 443959 ns 382584 ns 1.16
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/8 thread(s) 4468750 ns 4380375 ns 1.02
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/1 thread(s) 258666 ns 258417 ns 1.00
Dense(128 => 128, gelu)(128 x 128)/zygote/GPU/CUDA 226875.5 ns 220859 ns 1.03
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/2 thread(s) 379791 ns 421604 ns 0.90
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/4 thread(s) 474250 ns 411750 ns 1.15
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/8 thread(s) 4250750 ns 4491917 ns 0.95
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/1 thread(s) 270875 ns 271250 ns 1.00
Dense(128 => 128, relu)(128 x 128)/forward/CPU/2 thread(s) 352916 ns 329896 ns 1.07
Dense(128 => 128, relu)(128 x 128)/forward/CPU/4 thread(s) 315333 ns 300084 ns 1.05
Dense(128 => 128, relu)(128 x 128)/forward/CPU/8 thread(s) 733437.5 ns 750333 ns 0.98
Dense(128 => 128, relu)(128 x 128)/forward/CPU/1 thread(s) 54271 ns 54375 ns 1.00
Dense(128 => 128, relu)(128 x 128)/forward/GPU/CUDA 28442.5 ns 27841 ns 1.02
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/2 thread(s) 296333.5 ns 355792 ns 0.83
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/4 thread(s) 341000 ns 247167 ns 1.38
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/8 thread(s) 547833 ns 868125 ns 0.63
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/1 thread(s) 151666 ns 151750 ns 1.00
Dense(128 => 128, relu)(128 x 128)/zygote/GPU/CUDA 211615 ns 205968 ns 1.03
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/2 thread(s) 310375 ns 368375 ns 0.84
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/4 thread(s) 354833 ns 261709 ns 1.36
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/8 thread(s) 397666.5 ns 714208 ns 0.56
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/1 thread(s) 150875 ns 151125 ns 1.00
vgg16(32, 32, 3, 64)/forward/CPU/2 thread(s) 602928541 ns 601673542 ns 1.00
vgg16(32, 32, 3, 64)/forward/CPU/4 thread(s) 428260667 ns 433401687 ns 0.99
vgg16(32, 32, 3, 64)/forward/CPU/8 thread(s) 376195062.5 ns 378552750 ns 0.99
vgg16(32, 32, 3, 64)/forward/CPU/1 thread(s) 872683792 ns 874120625 ns 1.00
vgg16(32, 32, 3, 64)/forward/GPU/CUDA 7030918 ns 7030592 ns 1.00
vgg16(32, 32, 3, 64)/zygote/CPU/2 thread(s) 2004923333 ns 2007087354.5 ns 1.00
vgg16(32, 32, 3, 64)/zygote/CPU/4 thread(s) 1605834520.5 ns 1632009874.5 ns 0.98
vgg16(32, 32, 3, 64)/zygote/CPU/8 thread(s) 1586581187 ns 1618542583.5 ns 0.98
vgg16(32, 32, 3, 64)/zygote/CPU/1 thread(s) 2623747583 ns 2637429416 ns 0.99
vgg16(32, 32, 3, 64)/zygote/GPU/CUDA 25764798 ns 26054721.5 ns 0.99
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/2 thread(s) 524375 ns 523500 ns 1.00
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/4 thread(s) 398542 ns 435895.5 ns 0.91
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/8 thread(s) 2053208 ns 1828249.5 ns 1.12
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/1 thread(s) 865312.5 ns 866354 ns 1.00
Dense(512 => 512, gelu)(512 x 128)/forward/GPU/CUDA 47173 ns 47636 ns 0.99
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/2 thread(s) 1905417 ns 1763270.5 ns 1.08
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/4 thread(s) 1742250 ns 2797458.5 ns 0.62
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/8 thread(s) 15020687.5 ns 14370145.5 ns 1.05
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/1 thread(s) 2728374.5 ns 2769562.5 ns 0.99
Dense(512 => 512, gelu)(512 x 128)/zygote/GPU/CUDA 248474 ns 248789.5 ns 1.00
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/2 thread(s) 1969542 ns 1945916.5 ns 1.01
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/4 thread(s) 1845500 ns 5043500 ns 0.37
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/8 thread(s) 14844708.5 ns 14572416 ns 1.02
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/1 thread(s) 2764666.5 ns 2785979.5 ns 0.99
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/2 thread(s) 1573416 ns 1374375 ns 1.14
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/4 thread(s) 947833.5 ns 1189542 ns 0.80
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/8 thread(s) 1240792 ns 1224645.5 ns 1.01
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/1 thread(s) 2252813 ns 2299000 ns 0.98
mlp7layer_bn(tanh)(32 x 256)/forward/GPU/CUDA 585933 ns 583268.5 ns 1.00
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/2 thread(s) 5942334 ns 5918791 ns 1.00
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/4 thread(s) 8620916 ns 7147000 ns 1.21
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/8 thread(s) 25125208 ns 24359584 ns 1.03
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/1 thread(s) 7308250 ns 7320208 ns 1.00
mlp7layer_bn(tanh)(32 x 256)/zygote/GPU/CUDA 1387329 ns 1348690.5 ns 1.03
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/2 thread(s) 13102000 ns 13093542 ns 1.00
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/4 thread(s) 12175645.5 ns 12017167 ns 1.01
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/8 thread(s) 20369042 ns 20888000 ns 0.98
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/1 thread(s) 10679959 ns 10214417 ns 1.05
Dense(16 => 16, relu)(16 x 128)/forward/CPU/2 thread(s) 2500 ns 2375 ns 1.05
Dense(16 => 16, relu)(16 x 128)/forward/CPU/4 thread(s) 4917 ns 2500 ns 1.97
Dense(16 => 16, relu)(16 x 128)/forward/CPU/8 thread(s) 2875 ns 3333.5 ns 0.86
Dense(16 => 16, relu)(16 x 128)/forward/CPU/1 thread(s) 2458 ns 2958 ns 0.83
Dense(16 => 16, relu)(16 x 128)/forward/GPU/CUDA 24933 ns 24628 ns 1.01
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/2 thread(s) 7375 ns 7291.5 ns 1.01
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/4 thread(s) 7042 ns 7083 ns 0.99
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/8 thread(s) 7333.5 ns 7333.5 ns 1
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/1 thread(s) 7125 ns 7083 ns 1.01
Dense(16 => 16, relu)(16 x 128)/zygote/GPU/CUDA 213416.5 ns 209898.5 ns 1.02
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/2 thread(s) 8209 ns 8250 ns 1.00
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/4 thread(s) 8209 ns 8208 ns 1.00
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/8 thread(s) 8500 ns 8375 ns 1.01
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/1 thread(s) 5916 ns 5958 ns 0.99
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/2 thread(s) 10625 ns 10458 ns 1.02
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/4 thread(s) 16167 ns 12937.5 ns 1.25
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/8 thread(s) 10875 ns 10708 ns 1.02
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/1 thread(s) 7187.5 ns 7250 ns 0.99
Dense(16 => 16, gelu)(16 x 128)/forward/GPU/CUDA 25092 ns 24907 ns 1.01
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/2 thread(s) 19916 ns 19875 ns 1.00
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/4 thread(s) 19875 ns 20104.5 ns 0.99
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/8 thread(s) 20333 ns 20125 ns 1.01
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/1 thread(s) 19854.5 ns 20000 ns 0.99
Dense(16 => 16, gelu)(16 x 128)/zygote/GPU/CUDA 233641 ns 230594 ns 1.01
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/2 thread(s) 23562.5 ns 23583.5 ns 1.00
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/4 thread(s) 23416 ns 23708 ns 0.99
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/8 thread(s) 23750 ns 23625 ns 1.01
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/1 thread(s) 21208 ns 21333 ns 0.99
Dense(128 => 128, identity)(128 x 128)/forward/CPU/2 thread(s) 28958 ns 28459 ns 1.02
Dense(128 => 128, identity)(128 x 128)/forward/CPU/4 thread(s) 28416 ns 28542 ns 1.00
Dense(128 => 128, identity)(128 x 128)/forward/CPU/8 thread(s) 28500 ns 28770.5 ns 0.99
Dense(128 => 128, identity)(128 x 128)/forward/CPU/1 thread(s) 46041 ns 45917 ns 1.00
Dense(128 => 128, identity)(128 x 128)/forward/GPU/CUDA 26269 ns 25803 ns 1.02
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/2 thread(s) 226146 ns 230250 ns 0.98
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/4 thread(s) 274895.5 ns 288166 ns 0.95
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/8 thread(s) 4238916.5 ns 4212042 ns 1.01
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/1 thread(s) 145125 ns 145000 ns 1.00
Dense(128 => 128, identity)(128 x 128)/zygote/GPU/CUDA 211030 ns 207914 ns 1.01
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/2 thread(s) 337437.5 ns 342187.5 ns 0.99
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/4 thread(s) 316687.5 ns 333166 ns 0.95
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/8 thread(s) 800125 ns 411895.5 ns 1.94
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/1 thread(s) 161750 ns 160646 ns 1.01
Dense(16 => 16, identity)(16 x 128)/forward/CPU/2 thread(s) 1917 ns 1750 ns 1.10
Dense(16 => 16, identity)(16 x 128)/forward/CPU/4 thread(s) 2250 ns 1791 ns 1.26
Dense(16 => 16, identity)(16 x 128)/forward/CPU/8 thread(s) 2167 ns 2250 ns 0.96
Dense(16 => 16, identity)(16 x 128)/forward/CPU/1 thread(s) 1875 ns 1958 ns 0.96
Dense(16 => 16, identity)(16 x 128)/forward/GPU/CUDA 23176 ns 23251.5 ns 1.00
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/2 thread(s) 5333 ns 5208 ns 1.02
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/4 thread(s) 5000 ns 5208 ns 0.96
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/8 thread(s) 5417 ns 5500 ns 0.98
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/1 thread(s) 5167 ns 5291 ns 0.98
Dense(16 => 16, identity)(16 x 128)/zygote/GPU/CUDA 256165.5 ns 245332 ns 1.04
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/2 thread(s) 11250 ns 11291.5 ns 1.00
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/4 thread(s) 11417 ns 11375 ns 1.00
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/8 thread(s) 11375 ns 11458 ns 0.99
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/1 thread(s) 6958 ns 6959 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/2 thread(s) 79988042 ns 79898667 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/4 thread(s) 47941125 ns 49104563 ns 0.98
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/8 thread(s) 45005708 ns 44920792 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/1 thread(s) 151719833 ns 151542042 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/GPU/CUDA 2721409 ns 2713787 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/2 thread(s) 663734583 ns 665144875 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/4 thread(s) 409152750 ns 414328875 ns 0.99
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/8 thread(s) 398007625 ns 399605708 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/1 thread(s) 690336167 ns 687317792 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/GPU/CUDA 14673781 ns 14579874 ns 1.01
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/2 thread(s) 709715271 ns 718439500 ns 0.99
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/4 thread(s) 686379083 ns 685447833 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/8 thread(s) 980829291 ns 1000305625 ns 0.98
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/1 thread(s) 997377042 ns 992652792 ns 1.00

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.