-
Notifications
You must be signed in to change notification settings - Fork 64
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
docs: add a PINN tutorial with nested AD
- Loading branch information
Showing
13 changed files
with
269 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -40,3 +40,6 @@ LocalPreferences.toml | |
data/ | ||
|
||
benchmarks/results | ||
|
||
# Generated by tutorials | ||
pinn_nested_ad.gif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
[deps] | ||
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" | ||
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" | ||
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" | ||
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" | ||
Lux = "b2108857-7c20-44ae-9111-449ecde12c47" | ||
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" | ||
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" | ||
OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e" | ||
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" | ||
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" | ||
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,232 @@ | ||
# # Training a PINN on 2D PDE | ||
|
||
# In this tutorial we will go over using a PINN to solve 2D PDEs. We will be using the | ||
# system from [NeuralPDE Tutorials](https://docs.sciml.ai/NeuralPDE/stable/tutorials/gpu/). | ||
# However, we will be using our custom loss function and use nested AD capabilities of | ||
# Lux.jl. | ||
|
||
# This is a demonstration of Lux.jl. For serious usecases of PINNs, please refer to | ||
# the package: [NeuralPDE.jl](https://github.com/SciML/NeuralPDE.jl). | ||
|
||
# ## Package Imports | ||
|
||
using ADTypes, Lux, Optimisers, Zygote, Random, Printf, Statistics, MLUtils, OnlineStats, | ||
CairoMakie | ||
using LuxCUDA | ||
|
||
CUDA.allowscalar(false) | ||
|
||
const gdev = gpu_device() | ||
const cdev = cpu_device() | ||
|
||
# ## Problem Definition | ||
|
||
# Since Lux supports efficient nested AD upto 2nd order, we will rewrite the problem | ||
# with first order derivatives, so that we can compute the gradients of the loss using | ||
# 2nd order AD. | ||
|
||
# TODO: Equations for the PDE | ||
|
||
# ## Define the Neural Networks | ||
|
||
# All the networks take 3 input variables and output a scalar value. Here, we will define a | ||
# a wrapper over the 3 networks, so that we can train them using | ||
# [`Training.TrainState`](@ref). | ||
|
||
struct PINN{U, V, W} <: Lux.AbstractLuxContainerLayer{(:u, :v, :w)} | ||
u::U | ||
v::V | ||
w::W | ||
end | ||
|
||
function create_mlp(act, hidden_dims) | ||
return Chain( | ||
Dense(3 => hidden_dims, act), | ||
Dense(hidden_dims => hidden_dims, act), | ||
Dense(hidden_dims => hidden_dims, act), | ||
Dense(hidden_dims => 1) | ||
) | ||
end | ||
|
||
function PINN(; hidden_dims::Int=32) | ||
return PINN( | ||
create_mlp(tanh, hidden_dims), | ||
create_mlp(tanh, hidden_dims), | ||
create_mlp(tanh, hidden_dims) | ||
) | ||
end | ||
|
||
# ## Define the Loss Functions | ||
|
||
# We will define a custom loss function to compute the loss using 2nd order AD. We | ||
# will use the following loss function | ||
|
||
@views function physics_informed_loss_function( | ||
u::StatefulLuxLayer, v::StatefulLuxLayer, w::StatefulLuxLayer, xyt::AbstractArray) | ||
∂u_∂xyt = only(Zygote.gradient(sum ∘ u, xyt)) | ||
∂u_∂x, ∂u_∂y, ∂u_∂t = ∂u_∂xyt[1:1, :], ∂u_∂xyt[2:2, :], ∂u_∂xyt[3:3, :] | ||
∂v_∂x = only(Zygote.gradient(sum ∘ v, xyt))[1:1, :] | ||
v_xyt = v(xyt) | ||
∂w_∂y = only(Zygote.gradient(sum ∘ w, xyt))[2:2, :] | ||
w_xyt = w(xyt) | ||
return ( | ||
mean(abs2, ∂u_∂t .- ∂v_∂x .- ∂w_∂y) + | ||
mean(abs2, v_xyt .- ∂u_∂x) + | ||
mean(abs2, w_xyt .- ∂u_∂y) | ||
) | ||
end | ||
|
||
# Additionally, we need to compute the loss wrt the boundary conditions. | ||
|
||
function mse_loss_function(u::StatefulLuxLayer, target::AbstractArray, xyt::AbstractArray) | ||
return MSELoss()(u(xyt), target) | ||
end | ||
|
||
function loss_function(model, ps, st, (xyt, target_data, xyt_bc, target_bc)) | ||
u_net = StatefulLuxLayer{true}(model.u, ps.u, st.u) | ||
v_net = StatefulLuxLayer{true}(model.v, ps.v, st.v) | ||
w_net = StatefulLuxLayer{true}(model.w, ps.w, st.w) | ||
physics_loss = physics_informed_loss_function(u_net, v_net, w_net, xyt) | ||
data_loss = mse_loss_function(u_net, target_data, xyt) | ||
bc_loss = mse_loss_function(u_net, target_bc, xyt_bc) | ||
loss = physics_loss + data_loss + bc_loss | ||
return ( | ||
loss, | ||
(; u=u_net.st, v=v_net.st, w=w_net.st), | ||
(; physics_loss, data_loss, bc_loss) | ||
) | ||
end | ||
|
||
# ## Generate the Data | ||
|
||
# We will generate some random data to train the model on. We will take data on a square | ||
# spatial and temporal domain $x \in [0, 2]$, $y \in [0, 2]$, and $t \in [0, 2]$. Typically, | ||
# you want to be smarter about the sampling process, but for the sake of simplicity, we will | ||
# skip that. | ||
|
||
analytical_solution(x, y, t) = @. exp(x + y) * cos(x + y + 4t) | ||
analytical_solution(xyt) = analytical_solution(xyt[1, :], xyt[2, :], xyt[3, :]) | ||
|
||
begin | ||
grid_len = 16 | ||
|
||
grid = range(0.0f0, 2.0f0; length=grid_len) | ||
xyt = stack([[elem...] for elem in vec(collect(Iterators.product(grid, grid, grid)))]) | ||
|
||
target_data = reshape(analytical_solution(xyt), 1, :) | ||
|
||
bc_len = 512 | ||
|
||
x = collect(range(0.0f0, 2.0f0; length=bc_len)) | ||
y = collect(range(0.0f0, 2.0f0; length=bc_len)) | ||
t = collect(range(0.0f0, 2.0f0; length=bc_len)) | ||
|
||
xyt_bc = hcat( | ||
stack((x, y, zeros(Float32, bc_len)); dims=1), | ||
stack((zeros(Float32, bc_len), y, t); dims=1), | ||
stack((ones(Float32, bc_len) .* 2, y, t); dims=1), | ||
stack((x, zeros(Float32, bc_len), t); dims=1), | ||
stack((x, ones(Float32, bc_len) .* 2, t); dims=1) | ||
) | ||
target_bc = reshape(analytical_solution(xyt_bc), 1, :) | ||
|
||
min_target_bc, max_target_bc = extrema(target_bc) | ||
min_data, max_data = extrema(target_data) | ||
min_pde_val, max_pde_val = min(min_data, min_target_bc), max(max_data, max_target_bc) | ||
|
||
xyt = (xyt .- minimum(xyt)) ./ (maximum(xyt) .- minimum(xyt)) | ||
xyt_bc = (xyt_bc .- minimum(xyt_bc)) ./ (maximum(xyt_bc) .- minimum(xyt_bc)) | ||
target_bc = (target_bc .- min_pde_val) ./ (max_pde_val - min_pde_val) | ||
target_data = (target_data .- min_pde_val) ./ (max_pde_val - min_pde_val) | ||
end | ||
nothing #hide | ||
|
||
# ## Training | ||
|
||
function train_model(xyt, target_data, xyt_bc, target_bc; seed::Int=0, | ||
maxiters::Int=50000, hidden_dims::Int=32) | ||
rng = Random.default_rng() | ||
Random.seed!(rng, seed) | ||
|
||
pinn = PINN(; hidden_dims) | ||
ps, st = Lux.setup(rng, pinn) |> gdev | ||
|
||
bc_dataloader = DataLoader((xyt_bc, target_bc); batchsize=32, shuffle=true) |> gdev | ||
pde_dataloader = DataLoader((xyt, target_data); batchsize=32, shuffle=true) |> gdev | ||
|
||
train_state = Training.TrainState(pinn, ps, st, Adam(0.05f0)) | ||
lr = i -> i < 5000 ? 0.05f0 : (i < 10000 ? 0.005f0 : 0.0005f0) | ||
|
||
total_loss_tracker, physics_loss_tracker, data_loss_tracker, bc_loss_tracker = ntuple( | ||
_ -> Lag(Float32, 32), 4) | ||
|
||
iter = 1 | ||
for ((xyt_batch, target_data_batch), (xyt_bc_batch, target_bc_batch)) in zip( | ||
Iterators.cycle(pde_dataloader), Iterators.cycle(bc_dataloader)) | ||
Optimisers.adjust!(train_state, lr(iter)) | ||
|
||
_, loss, stats, train_state = Training.single_train_step!( | ||
AutoZygote(), loss_function, ( | ||
xyt_batch, target_data_batch, xyt_bc_batch, target_bc_batch), | ||
train_state) | ||
|
||
fit!(total_loss_tracker, loss) | ||
fit!(physics_loss_tracker, stats.physics_loss) | ||
fit!(data_loss_tracker, stats.data_loss) | ||
fit!(bc_loss_tracker, stats.bc_loss) | ||
|
||
mean_loss = mean(OnlineStats.value(total_loss_tracker)) | ||
mean_physics_loss = mean(OnlineStats.value(physics_loss_tracker)) | ||
mean_data_loss = mean(OnlineStats.value(data_loss_tracker)) | ||
mean_bc_loss = mean(OnlineStats.value(bc_loss_tracker)) | ||
|
||
isnan(loss) && throw(ArgumentError("NaN Loss Detected")) | ||
|
||
if iter % 500 == 1 || iter == maxiters | ||
@printf "Iteration: [%5d / %5d] \t Loss: %.9f (%.9f) \t Physics Loss: %.9f \ | ||
(%.9f) \t Data Loss: %.9f (%.9f) \t BC \ | ||
Loss: %.9f (%.9f)\n" iter maxiters loss mean_loss stats.physics_loss mean_physics_loss stats.data_loss mean_data_loss stats.bc_loss mean_bc_loss | ||
end | ||
|
||
iter += 1 | ||
iter ≥ maxiters && break | ||
end | ||
|
||
return StatefulLuxLayer{true}( | ||
pinn, cdev(train_state.parameters), cdev(train_state.states)) | ||
end | ||
|
||
trained_model = train_model(xyt, target_data, xyt_bc, target_bc) | ||
trained_u = Lux.testmode(StatefulLuxLayer{true}( | ||
trained_model.model.u, trained_model.ps.u, trained_model.st.u)) | ||
nothing #hide | ||
|
||
# ## Visualizing the Results | ||
ts, xs, ys = 0.0f0:0.05f0:2.0f0, 0.0f0:0.02f0:2.0f0, 0.0f0:0.02f0:2.0f0 | ||
grid = stack([[elem...] for elem in vec(collect(Iterators.product(xs, ys, ts)))]) | ||
|
||
u_real = reshape(analytical_solution(grid), length(xs), length(ys), length(ts)) | ||
|
||
grid_normalized = (grid .- minimum(grid)) ./ (maximum(grid) .- minimum(grid)) | ||
u_pred = reshape(trained_u(grid_normalized), length(xs), length(ys), length(ts)) | ||
u_pred = u_pred .* (max_pde_val - min_pde_val) .+ min_pde_val | ||
|
||
begin | ||
fig = Figure() | ||
ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y") | ||
errs = [abs.(u_pred[:, :, i] .- u_real[:, :, i]) for i in 1:length(ts)] | ||
Colorbar(fig[1, 2]; limits=extrema(stack(errs))) | ||
|
||
CairoMakie.record(fig, "pinn_nested_ad.gif", 1:length(ts); framerate=10) do i | ||
ax.title = "Abs. Predictor Error | Time: $(ts[i])" | ||
err = errs[i] | ||
contour!(ax, xs, ys, err; levels=10, linewidth=2) | ||
heatmap!(ax, xs, ys, err) | ||
return fig | ||
end | ||
|
||
fig | ||
end | ||
nothing #hide | ||
|
||
#  |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
3ca41c8
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lux Benchmarks
Dense(512 => 512, identity)(512 x 128)/forward/CPU/2 thread(s)
412583
ns412125
ns1.00
Dense(512 => 512, identity)(512 x 128)/forward/CPU/4 thread(s)
322979.5
ns243958
ns1.32
Dense(512 => 512, identity)(512 x 128)/forward/CPU/8 thread(s)
243687.5
ns323708
ns0.75
Dense(512 => 512, identity)(512 x 128)/forward/CPU/1 thread(s)
738583
ns739541
ns1.00
Dense(512 => 512, identity)(512 x 128)/forward/GPU/CUDA
43341
ns43923
ns0.99
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/2 thread(s)
1331541.5
ns1371937.5
ns0.97
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/4 thread(s)
2407792
ns1261354.5
ns1.91
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/8 thread(s)
16439208.5
ns14008417
ns1.17
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/1 thread(s)
2194563
ns2194708
ns1.00
Dense(512 => 512, identity)(512 x 128)/zygote/GPU/CUDA
205021
ns205355.5
ns1.00
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/2 thread(s)
1426479.5
ns1417791
ns1.01
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/4 thread(s)
895625
ns887292
ns1.01
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/8 thread(s)
1543917
ns1812208
ns0.85
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/1 thread(s)
2206208.5
ns2212208
ns1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/2 thread(s)
1777646
ns1733292
ns1.03
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/4 thread(s)
1078167
ns1102709
ns0.98
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/8 thread(s)
1530958
ns1521208.5
ns1.01
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/1 thread(s)
3007208
ns3013042
ns1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/GPU/CUDA
207880
ns206261
ns1.01
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/2 thread(s)
12178459
ns12142666.5
ns1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/4 thread(s)
8815750
ns8846250
ns1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/8 thread(s)
9208125
ns9207354
ns1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/1 thread(s)
18565750
ns18580625
ns1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/GPU/CUDA
1492124
ns1487526
ns1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/2 thread(s)
17280708
ns17293999.5
ns1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/4 thread(s)
13973458
ns13993542
ns1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/8 thread(s)
14487354.5
ns14483333.5
ns1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/1 thread(s)
21838146
ns21828604
ns1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/2 thread(s)
249950041
ns250797875.5
ns1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/4 thread(s)
148180333
ns148823333
ns1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/8 thread(s)
116724083
ns115883354
ns1.01
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/1 thread(s)
447231500
ns454090334
ns0.98
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/GPU/CUDA
5449958
ns5479922
ns0.99
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/2 thread(s)
1223396291
ns1224986000
ns1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/4 thread(s)
929732416
ns928765000
ns1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/8 thread(s)
832528750
ns827976542
ns1.01
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/1 thread(s)
1633536917
ns1644101834
ns0.99
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/GPU/CUDA
31232077
ns31171453
ns1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/2 thread(s)
1127648666
ns1135178542
ns0.99
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/4 thread(s)
1002243833.5
ns999714166.5
ns1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/8 thread(s)
1330111333.5
ns1303828854
ns1.02
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/1 thread(s)
1732141146
ns1741741896
ns0.99
lenet(28, 28, 1, 32)/forward/CPU/2 thread(s)
1037166
ns1090792
ns0.95
lenet(28, 28, 1, 32)/forward/CPU/4 thread(s)
1626521
ns1637666
ns0.99
lenet(28, 28, 1, 32)/forward/CPU/8 thread(s)
3777708
ns3582708
ns1.05
lenet(28, 28, 1, 32)/forward/CPU/1 thread(s)
781583
ns786167
ns0.99
lenet(28, 28, 1, 32)/forward/GPU/CUDA
262013.5
ns262029.5
ns1.00
lenet(28, 28, 1, 32)/zygote/CPU/2 thread(s)
3044041.5
ns2971792
ns1.02
lenet(28, 28, 1, 32)/zygote/CPU/4 thread(s)
4097042
ns4123750
ns0.99
lenet(28, 28, 1, 32)/zygote/CPU/8 thread(s)
11116208
ns10426437.5
ns1.07
lenet(28, 28, 1, 32)/zygote/CPU/1 thread(s)
3145292
ns3148229
ns1.00
lenet(28, 28, 1, 32)/zygote/GPU/CUDA
1092823.5
ns1092866.5
ns1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/2 thread(s)
2334666.5
ns2326979.5
ns1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/4 thread(s)
1419459
ns1322083
ns1.07
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/8 thread(s)
1571666
ns1673459
ns0.94
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/1 thread(s)
4190125
ns4208000
ns1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/GPU/CUDA
208094.5
ns208009
ns1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/2 thread(s)
19396000
ns19416729
ns1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/4 thread(s)
16089229
ns16098667
ns1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/8 thread(s)
17212895.5
ns17355625.5
ns0.99
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/1 thread(s)
25854646
ns25850083
ns1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/GPU/CUDA
1594895
ns1585986
ns1.01
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/2 thread(s)
34213875
ns34150479.5
ns1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/4 thread(s)
31031917
ns30929208
ns1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/8 thread(s)
31076583
ns30921458
ns1.01
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/1 thread(s)
36708292
ns36974958
ns0.99
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/2 thread(s)
4528125
ns4531521
ns1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/4 thread(s)
2779625
ns2560770.5
ns1.09
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/8 thread(s)
2669750
ns2921667
ns0.91
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/1 thread(s)
8386375
ns8371042
ns1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/GPU/CUDA
420922
ns418970
ns1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/2 thread(s)
38785750
ns38979334
ns1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/4 thread(s)
32129375
ns32095416.5
ns1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/8 thread(s)
32277375
ns32177750
ns1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/1 thread(s)
51825125
ns52201375
ns0.99
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/GPU/CUDA
2627381
ns2623471.5
ns1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/2 thread(s)
88427125
ns88784146
ns1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/4 thread(s)
113463375
ns115080917
ns0.99
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/8 thread(s)
227295542
ns219472208
ns1.04
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/1 thread(s)
74279083
ns74369625
ns1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/2 thread(s)
268165959
ns268757584
ns1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/4 thread(s)
158761895.5
ns156384958
ns1.02
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/8 thread(s)
123751562.5
ns126657687.5
ns0.98
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/1 thread(s)
487404084
ns493203959
ns0.99
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/GPU/CUDA
6976350
ns7020761
ns0.99
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/2 thread(s)
1472349770.5
ns1470831958.5
ns1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/4 thread(s)
1171105895.5
ns1166912584
ns1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/8 thread(s)
1066855854.5
ns1069943646
ns1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/1 thread(s)
2006883229.5
ns2001019750
ns1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/GPU/CUDA
34648809
ns34698558
ns1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/2 thread(s)
1717718833
ns1717061375
ns1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/4 thread(s)
1535202521
ns1532673563
ns1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/8 thread(s)
1878248417
ns1763256208
ns1.07
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/1 thread(s)
2205912250
ns2235971083
ns0.99
lenet(28, 28, 1, 128)/forward/CPU/2 thread(s)
2014792
ns2033917
ns0.99
lenet(28, 28, 1, 128)/forward/CPU/4 thread(s)
3001875
ns3001750
ns1.00
lenet(28, 28, 1, 128)/forward/CPU/8 thread(s)
8115875
ns6995583
ns1.16
lenet(28, 28, 1, 128)/forward/CPU/1 thread(s)
2432666
ns2410250.5
ns1.01
lenet(28, 28, 1, 128)/forward/GPU/CUDA
266572
ns277586.5
ns0.96
lenet(28, 28, 1, 128)/zygote/CPU/2 thread(s)
9287083.5
ns9596208.5
ns0.97
lenet(28, 28, 1, 128)/zygote/CPU/4 thread(s)
12062125
ns11959208.5
ns1.01
lenet(28, 28, 1, 128)/zygote/CPU/8 thread(s)
25629896
ns24601937.5
ns1.04
lenet(28, 28, 1, 128)/zygote/CPU/1 thread(s)
11743625.5
ns11737000
ns1.00
lenet(28, 28, 1, 128)/zygote/GPU/CUDA
1194959.5
ns1205760
ns0.99
vgg16(32, 32, 3, 32)/forward/CPU/2 thread(s)
385030062.5
ns380715583
ns1.01
vgg16(32, 32, 3, 32)/forward/CPU/4 thread(s)
288360083
ns309134958
ns0.93
vgg16(32, 32, 3, 32)/forward/CPU/8 thread(s)
255891729
ns240484604
ns1.06
vgg16(32, 32, 3, 32)/forward/CPU/1 thread(s)
454057438
ns455377687.5
ns1.00
vgg16(32, 32, 3, 32)/forward/GPU/CUDA
4834603
ns4901230.5
ns0.99
vgg16(32, 32, 3, 32)/zygote/CPU/2 thread(s)
1159548125
ns1156536042
ns1.00
vgg16(32, 32, 3, 32)/zygote/CPU/4 thread(s)
928639750
ns934444375
ns0.99
vgg16(32, 32, 3, 32)/zygote/CPU/8 thread(s)
1041025834
ns937955958
ns1.11
vgg16(32, 32, 3, 32)/zygote/CPU/1 thread(s)
1400301917
ns1404766250
ns1.00
vgg16(32, 32, 3, 32)/zygote/GPU/CUDA
17860514
ns19140053
ns0.93
lenet(28, 28, 1, 64)/forward/CPU/2 thread(s)
1063417
ns1044417
ns1.02
lenet(28, 28, 1, 64)/forward/CPU/4 thread(s)
2031167
ns1665375
ns1.22
lenet(28, 28, 1, 64)/forward/CPU/8 thread(s)
6215979
ns4659125
ns1.33
lenet(28, 28, 1, 64)/forward/CPU/1 thread(s)
1289687.5
ns1381458
ns0.93
lenet(28, 28, 1, 64)/forward/GPU/CUDA
274190
ns277017
ns0.99
lenet(28, 28, 1, 64)/zygote/CPU/2 thread(s)
6281917
ns6471291
ns0.97
lenet(28, 28, 1, 64)/zygote/CPU/4 thread(s)
12390375
ns13171521
ns0.94
lenet(28, 28, 1, 64)/zygote/CPU/8 thread(s)
21407437
ns19155167
ns1.12
lenet(28, 28, 1, 64)/zygote/CPU/1 thread(s)
6091125
ns6084416
ns1.00
lenet(28, 28, 1, 64)/zygote/GPU/CUDA
1242182
ns1234554.5
ns1.01
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/2 thread(s)
70428792
ns70412479
ns1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/4 thread(s)
43611000
ns43727396
ns1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/8 thread(s)
39667125
ns39654166
ns1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/1 thread(s)
132442542
ns132669229.5
ns1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/GPU/CUDA
1932803.5
ns1946871.5
ns0.99
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/2 thread(s)
355601563
ns356558937
ns1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/4 thread(s)
270115458
ns269901750
ns1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/8 thread(s)
253762937.5
ns253901458
ns1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/1 thread(s)
535226770.5
ns543773875
ns0.98
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/GPU/CUDA
12278246
ns12296125.5
ns1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/2 thread(s)
399133667
ns394712500
ns1.01
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/4 thread(s)
394300875
ns374705250
ns1.05
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/8 thread(s)
679147666.5
ns695586875
ns0.98
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/1 thread(s)
710284375
ns709322375
ns1.00
vgg16(32, 32, 3, 128)/forward/CPU/2 thread(s)
1194640667
ns1186149833
ns1.01
vgg16(32, 32, 3, 128)/forward/CPU/4 thread(s)
689604666
ns826656438
ns0.83
vgg16(32, 32, 3, 128)/forward/CPU/8 thread(s)
645000312
ns633369958
ns1.02
vgg16(32, 32, 3, 128)/forward/CPU/1 thread(s)
1774100021
ns1780054896
ns1.00
vgg16(32, 32, 3, 128)/forward/GPU/CUDA
12543961
ns12309879.5
ns1.02
vgg16(32, 32, 3, 128)/zygote/CPU/2 thread(s)
3679223834
ns3667291125
ns1.00
vgg16(32, 32, 3, 128)/zygote/CPU/4 thread(s)
2826435875
ns2822189750
ns1.00
vgg16(32, 32, 3, 128)/zygote/CPU/8 thread(s)
2707509167
ns2708758541
ns1.00
vgg16(32, 32, 3, 128)/zygote/CPU/1 thread(s)
5055415292
ns5073574916
ns1.00
vgg16(32, 32, 3, 128)/zygote/GPU/CUDA
49466162
ns49860253.5
ns0.99
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/2 thread(s)
3397917
ns3405875.5
ns1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/4 thread(s)
2077646
ns2065250
ns1.01
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/8 thread(s)
2508042
ns2540250
ns0.99
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/1 thread(s)
6018583.5
ns6014604
ns1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/GPU/CUDA
330593.5
ns343548
ns0.96
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/2 thread(s)
25938625.5
ns25973124.5
ns1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/4 thread(s)
18943166.5
ns18895979.5
ns1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/8 thread(s)
19554000
ns19531416.5
ns1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/1 thread(s)
39254458
ns39260291
ns1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/GPU/CUDA
2474870
ns2465534
ns1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/2 thread(s)
55544542
ns55438083
ns1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/4 thread(s)
82396771
ns82776500
ns1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/8 thread(s)
173728042
ns170298791.5
ns1.02
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/1 thread(s)
45516333
ns45669667
ns1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/2 thread(s)
1779250
ns1778750
ns1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/4 thread(s)
1102875
ns1090500
ns1.01
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/8 thread(s)
1584416
ns1558104.5
ns1.02
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/1 thread(s)
3020250
ns3024604.5
ns1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/GPU/CUDA
214596.5
ns213362
ns1.01
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/2 thread(s)
12529625
ns12540583.5
ns1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/4 thread(s)
9208687.5
ns9227000
ns1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/8 thread(s)
9658437
ns9657125
ns1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/1 thread(s)
18982292
ns19007583
ns1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/GPU/CUDA
1539988
ns1541996
ns1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/2 thread(s)
17638625
ns17663416.5
ns1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/4 thread(s)
14351250
ns14318000
ns1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/8 thread(s)
14600292
ns14637687.5
ns1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/1 thread(s)
22151625
ns22165500
ns1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/2 thread(s)
70460291
ns70531042
ns1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/4 thread(s)
43523750
ns43664104
ns1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/8 thread(s)
39539833
ns39661792
ns1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/1 thread(s)
132477750
ns132673333
ns1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/GPU/CUDA
1881252.5
ns1958870.5
ns0.96
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/2 thread(s)
359873333.5
ns360097833
ns1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/4 thread(s)
345894479
ns347045875
ns1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/8 thread(s)
305415375
ns303749416
ns1.01
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/1 thread(s)
725056958
ns730460916
ns0.99
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/GPU/CUDA
13380753
ns13376571.5
ns1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/2 thread(s)
421101688
ns417572000
ns1.01
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/4 thread(s)
421140000
ns427267750
ns0.99
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/8 thread(s)
748051292
ns741225500
ns1.01
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/1 thread(s)
715116625
ns714873709
ns1.00
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/2 thread(s)
1695875
ns1698792
ns1.00
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/4 thread(s)
1355249.5
ns1054000
ns1.29
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/8 thread(s)
1159396
ns1350854
ns0.86
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/1 thread(s)
2409792
ns2441916
ns0.99
mlp7layer_bn(gelu)(32 x 256)/forward/GPU/CUDA
588933.5
ns591065.5
ns1.00
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/2 thread(s)
8963979.5
ns8953166
ns1.00
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/4 thread(s)
13000041
ns13693270.5
ns0.95
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/8 thread(s)
33199749.5
ns30593292
ns1.09
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/1 thread(s)
9835041
ns9833291
ns1.00
mlp7layer_bn(gelu)(32 x 256)/zygote/GPU/CUDA
1481667.5
ns1490752.5
ns0.99
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/2 thread(s)
17747854.5
ns18112458
ns0.98
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/4 thread(s)
17252604.5
ns17467500.5
ns0.99
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/8 thread(s)
31103916.5
ns29723417
ns1.05
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/1 thread(s)
14355042
ns14333291
ns1.00
Dense(512 => 512, relu)(512 x 128)/forward/CPU/2 thread(s)
669500.5
ns678667
ns0.99
Dense(512 => 512, relu)(512 x 128)/forward/CPU/4 thread(s)
556667
ns520708
ns1.07
Dense(512 => 512, relu)(512 x 128)/forward/CPU/8 thread(s)
1058916.5
ns1030125
ns1.03
Dense(512 => 512, relu)(512 x 128)/forward/CPU/1 thread(s)
725958
ns724041.5
ns1.00
Dense(512 => 512, relu)(512 x 128)/forward/GPU/CUDA
48545
ns48348.5
ns1.00
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/2 thread(s)
1480334
ns1565250
ns0.95
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/4 thread(s)
1045958
ns1021042
ns1.02
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/8 thread(s)
1649187
ns1403604
ns1.17
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/1 thread(s)
2243312
ns2276771
ns0.99
Dense(512 => 512, relu)(512 x 128)/zygote/GPU/CUDA
241138.5
ns240894
ns1.00
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/2 thread(s)
1523125
ns1558958.5
ns0.98
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/4 thread(s)
1079125
ns1066687
ns1.01
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/8 thread(s)
1484062.5
ns1686104.5
ns0.88
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/1 thread(s)
2259292
ns2226166
ns1.01
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/2 thread(s)
3401000.5
ns3395333
ns1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/4 thread(s)
2066583.5
ns2056708
ns1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/8 thread(s)
2508167
ns2516479.5
ns1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/1 thread(s)
6010687.5
ns6008875
ns1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/GPU/CUDA
286726
ns288149
ns1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/2 thread(s)
24062625
ns24072041.5
ns1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/4 thread(s)
17164833
ns17187916
ns1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/8 thread(s)
17141292
ns17151041.5
ns1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/1 thread(s)
37492334
ns37508458.5
ns1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/GPU/CUDA
2402302
ns2405382
ns1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/2 thread(s)
53536916
ns53630708.5
ns1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/4 thread(s)
83532937.5
ns85883417
ns0.97
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/8 thread(s)
171297229.5
ns169400792
ns1.01
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/1 thread(s)
44566417
ns44653333.5
ns1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/2 thread(s)
249899167
ns250291979
ns1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/4 thread(s)
147932291
ns148449041
ns1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/8 thread(s)
116470479.5
ns115844229
ns1.01
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/1 thread(s)
450194500
ns454457542
ns0.99
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/GPU/CUDA
5441310
ns5457486
ns1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/2 thread(s)
1101316834
ns1104315584
ns1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/4 thread(s)
855731854
ns856185292
ns1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/8 thread(s)
827979937.5
ns826815646
ns1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/1 thread(s)
1754256625
ns1779319041
ns0.99
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/GPU/CUDA
28790457
ns28848435
ns1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/2 thread(s)
1014882938
ns1008544916.5
ns1.01
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/4 thread(s)
964872792
ns978883958
ns0.99
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/8 thread(s)
1306006041
ns1273700875
ns1.03
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/1 thread(s)
1738222270.5
ns1738185250
ns1.00
mlp7layer_bn(relu)(32 x 256)/forward/CPU/2 thread(s)
1311167
ns1308271
ns1.00
mlp7layer_bn(relu)(32 x 256)/forward/CPU/4 thread(s)
957750
ns736416
ns1.30
mlp7layer_bn(relu)(32 x 256)/forward/CPU/8 thread(s)
703375
ns970250.5
ns0.72
mlp7layer_bn(relu)(32 x 256)/forward/CPU/1 thread(s)
1939625
ns1992812.5
ns0.97
mlp7layer_bn(relu)(32 x 256)/forward/GPU/CUDA
571080
ns573191
ns1.00
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/2 thread(s)
6006770.5
ns5738375
ns1.05
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/4 thread(s)
6329521
ns9161583
ns0.69
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/8 thread(s)
25490292
ns23815937.5
ns1.07
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/1 thread(s)
7082792
ns7075209
ns1.00
mlp7layer_bn(relu)(32 x 256)/zygote/GPU/CUDA
1359778.5
ns1420790.5
ns0.96
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/2 thread(s)
11590167
ns10487208
ns1.11
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/4 thread(s)
10239645.5
ns10551083
ns0.97
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/8 thread(s)
18128834
ns17215062.5
ns1.05
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/1 thread(s)
8225354.5
ns8751250
ns0.94
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/2 thread(s)
362875
ns407270.5
ns0.89
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/4 thread(s)
363208
ns377167
ns0.96
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/8 thread(s)
3032958
ns1964791
ns1.54
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/1 thread(s)
87687.5
ns88084
ns1.00
Dense(128 => 128, gelu)(128 x 128)/forward/GPU/CUDA
28003
ns28331
ns0.99
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/2 thread(s)
388708
ns349125.5
ns1.11
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/4 thread(s)
440375
ns443959
ns0.99
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/8 thread(s)
4703354
ns4468750
ns1.05
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/1 thread(s)
259209
ns258666
ns1.00
Dense(128 => 128, gelu)(128 x 128)/zygote/GPU/CUDA
220116.5
ns226875.5
ns0.97
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/2 thread(s)
419708.5
ns379791
ns1.11
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/4 thread(s)
470625
ns474250
ns0.99
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/8 thread(s)
4962583
ns4250750
ns1.17
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/1 thread(s)
270875
ns270875
ns1
Dense(128 => 128, relu)(128 x 128)/forward/CPU/2 thread(s)
307959
ns352916
ns0.87
Dense(128 => 128, relu)(128 x 128)/forward/CPU/4 thread(s)
298292
ns315333
ns0.95
Dense(128 => 128, relu)(128 x 128)/forward/CPU/8 thread(s)
764833
ns733437.5
ns1.04
Dense(128 => 128, relu)(128 x 128)/forward/CPU/1 thread(s)
54917
ns54271
ns1.01
Dense(128 => 128, relu)(128 x 128)/forward/GPU/CUDA
27854
ns28442.5
ns0.98
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/2 thread(s)
352750
ns296333.5
ns1.19
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/4 thread(s)
336000
ns341000
ns0.99
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/8 thread(s)
887229.5
ns547833
ns1.62
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/1 thread(s)
151687.5
ns151666
ns1.00
Dense(128 => 128, relu)(128 x 128)/zygote/GPU/CUDA
205379
ns211615
ns0.97
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/2 thread(s)
366770.5
ns310375
ns1.18
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/4 thread(s)
350334
ns354833
ns0.99
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/8 thread(s)
470875
ns397666.5
ns1.18
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/1 thread(s)
151188
ns150875
ns1.00
vgg16(32, 32, 3, 64)/forward/CPU/2 thread(s)
606032542
ns602928541
ns1.01
vgg16(32, 32, 3, 64)/forward/CPU/4 thread(s)
429552104
ns428260667
ns1.00
vgg16(32, 32, 3, 64)/forward/CPU/8 thread(s)
384048291
ns376195062.5
ns1.02
vgg16(32, 32, 3, 64)/forward/CPU/1 thread(s)
874614500
ns872683792
ns1.00
vgg16(32, 32, 3, 64)/forward/GPU/CUDA
7024831
ns7030918
ns1.00
vgg16(32, 32, 3, 64)/zygote/CPU/2 thread(s)
2012811729
ns2004923333
ns1.00
vgg16(32, 32, 3, 64)/zygote/CPU/4 thread(s)
1612557354
ns1605834520.5
ns1.00
vgg16(32, 32, 3, 64)/zygote/CPU/8 thread(s)
1572362875
ns1586581187
ns0.99
vgg16(32, 32, 3, 64)/zygote/CPU/1 thread(s)
2635509917
ns2623747583
ns1.00
vgg16(32, 32, 3, 64)/zygote/GPU/CUDA
25977885
ns25764798
ns1.01
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/2 thread(s)
521667
ns524375
ns0.99
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/4 thread(s)
439708
ns398542
ns1.10
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/8 thread(s)
2731999.5
ns2053208
ns1.33
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/1 thread(s)
865959
ns865312.5
ns1.00
Dense(512 => 512, gelu)(512 x 128)/forward/GPU/CUDA
47570
ns47173
ns1.01
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/2 thread(s)
1889812.5
ns1905417
ns0.99
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/4 thread(s)
2797000
ns1742250
ns1.61
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/8 thread(s)
16236125
ns15020687.5
ns1.08
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/1 thread(s)
2650416
ns2728374.5
ns0.97
Dense(512 => 512, gelu)(512 x 128)/zygote/GPU/CUDA
249198.5
ns248474
ns1.00
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/2 thread(s)
1925354.5
ns1969542
ns0.98
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/4 thread(s)
5070458
ns1845500
ns2.75
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/8 thread(s)
16364250
ns14844708.5
ns1.10
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/1 thread(s)
2752417
ns2764666.5
ns1.00
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/2 thread(s)
1507000
ns1573416
ns0.96
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/4 thread(s)
1228583
ns947833.5
ns1.30
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/8 thread(s)
1068166.5
ns1240792
ns0.86
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/1 thread(s)
2208000
ns2252813
ns0.98
mlp7layer_bn(tanh)(32 x 256)/forward/GPU/CUDA
589459
ns585933
ns1.01
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/2 thread(s)
5951958
ns5942334
ns1.00
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/4 thread(s)
4655854.5
ns8620916
ns0.54
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/8 thread(s)
27127042
ns25125208
ns1.08
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/1 thread(s)
6596583
ns7308250
ns0.90
mlp7layer_bn(tanh)(32 x 256)/zygote/GPU/CUDA
1347970
ns1387329
ns0.97
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/2 thread(s)
12790458
ns13102000
ns0.98
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/4 thread(s)
12034667
ns12175645.5
ns0.99
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/8 thread(s)
22409521.5
ns20369042
ns1.10
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/1 thread(s)
10615250
ns10679959
ns0.99
Dense(16 => 16, relu)(16 x 128)/forward/CPU/2 thread(s)
2292
ns2500
ns0.92
Dense(16 => 16, relu)(16 x 128)/forward/CPU/4 thread(s)
2500
ns4917
ns0.51
Dense(16 => 16, relu)(16 x 128)/forward/CPU/8 thread(s)
3125
ns2875
ns1.09
Dense(16 => 16, relu)(16 x 128)/forward/CPU/1 thread(s)
2292
ns2458
ns0.93
Dense(16 => 16, relu)(16 x 128)/forward/GPU/CUDA
24734
ns24933
ns0.99
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/2 thread(s)
7416
ns7375
ns1.01
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/4 thread(s)
7416.5
ns7042
ns1.05
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/8 thread(s)
7666
ns7333.5
ns1.05
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/1 thread(s)
7167
ns7125
ns1.01
Dense(16 => 16, relu)(16 x 128)/zygote/GPU/CUDA
210121.5
ns213416.5
ns0.98
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/2 thread(s)
8250
ns8209
ns1.00
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/4 thread(s)
8250
ns8209
ns1.00
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/8 thread(s)
8417
ns8500
ns0.99
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/1 thread(s)
5958
ns5916
ns1.01
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/2 thread(s)
10500
ns10625
ns0.99
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/4 thread(s)
12896
ns16167
ns0.80
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/8 thread(s)
10666
ns10875
ns0.98
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/1 thread(s)
7083
ns7187.5
ns0.99
Dense(16 => 16, gelu)(16 x 128)/forward/GPU/CUDA
24767
ns25092
ns0.99
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/2 thread(s)
20208
ns19916
ns1.01
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/4 thread(s)
20041
ns19875
ns1.01
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/8 thread(s)
20209
ns20333
ns0.99
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/1 thread(s)
20125
ns19854.5
ns1.01
Dense(16 => 16, gelu)(16 x 128)/zygote/GPU/CUDA
230521
ns233641
ns0.99
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/2 thread(s)
23625
ns23562.5
ns1.00
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/4 thread(s)
23500
ns23416
ns1.00
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/8 thread(s)
23750
ns23750
ns1
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/1 thread(s)
21125
ns21208
ns1.00
Dense(128 => 128, identity)(128 x 128)/forward/CPU/2 thread(s)
28292
ns28958
ns0.98
Dense(128 => 128, identity)(128 x 128)/forward/CPU/4 thread(s)
28500
ns28416
ns1.00
Dense(128 => 128, identity)(128 x 128)/forward/CPU/8 thread(s)
28708
ns28500
ns1.01
Dense(128 => 128, identity)(128 x 128)/forward/CPU/1 thread(s)
47396
ns46041
ns1.03
Dense(128 => 128, identity)(128 x 128)/forward/GPU/CUDA
25741
ns26269
ns0.98
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/2 thread(s)
220542
ns226146
ns0.98
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/4 thread(s)
281604.5
ns274895.5
ns1.02
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/8 thread(s)
4282416
ns4238916.5
ns1.01
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/1 thread(s)
146208.5
ns145125
ns1.01
Dense(128 => 128, identity)(128 x 128)/zygote/GPU/CUDA
210161.5
ns211030
ns1.00
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/2 thread(s)
330833
ns337437.5
ns0.98
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/4 thread(s)
321584
ns316687.5
ns1.02
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/8 thread(s)
763437.5
ns800125
ns0.95
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/1 thread(s)
161895.5
ns161750
ns1.00
Dense(16 => 16, identity)(16 x 128)/forward/CPU/2 thread(s)
2000
ns1917
ns1.04
Dense(16 => 16, identity)(16 x 128)/forward/CPU/4 thread(s)
1875
ns2250
ns0.83
Dense(16 => 16, identity)(16 x 128)/forward/CPU/8 thread(s)
2541
ns2167
ns1.17
Dense(16 => 16, identity)(16 x 128)/forward/CPU/1 thread(s)
1666
ns1875
ns0.89
Dense(16 => 16, identity)(16 x 128)/forward/GPU/CUDA
23005
ns23176
ns0.99
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/2 thread(s)
5250
ns5333
ns0.98
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/4 thread(s)
5458
ns5000
ns1.09
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/8 thread(s)
5541
ns5417
ns1.02
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/1 thread(s)
5500
ns5167
ns1.06
Dense(16 => 16, identity)(16 x 128)/zygote/GPU/CUDA
256932
ns256165.5
ns1.00
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/2 thread(s)
11250
ns11250
ns1
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/4 thread(s)
11250
ns11417
ns0.99
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/8 thread(s)
11458
ns11375
ns1.01
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/1 thread(s)
7000
ns6958
ns1.01
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/2 thread(s)
79995437.5
ns79988042
ns1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/4 thread(s)
49006333
ns47941125
ns1.02
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/8 thread(s)
43270416.5
ns45005708
ns0.96
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/1 thread(s)
151307500
ns151719833
ns1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/GPU/CUDA
2675260
ns2721409
ns0.98
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/2 thread(s)
673208958
ns663734583
ns1.01
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/4 thread(s)
413629250
ns409152750
ns1.01
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/8 thread(s)
396318917
ns398007625
ns1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/1 thread(s)
684603375
ns690336167
ns0.99
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/GPU/CUDA
14606131
ns14673781
ns1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/2 thread(s)
713812625
ns709715271
ns1.01
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/4 thread(s)
676256333
ns686379083
ns0.99
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/8 thread(s)
1061049792
ns980829291
ns1.08
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/1 thread(s)
1000454667
ns997377042
ns1.00
This comment was automatically generated by workflow using github-action-benchmark.