Skip to content

Commit

Permalink
docs: add a PINN tutorial with nested AD
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 15, 2024
1 parent e7caa61 commit 3ca41c8
Show file tree
Hide file tree
Showing 13 changed files with 269 additions and 8 deletions.
1 change: 0 additions & 1 deletion .JuliaFormatter.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,3 @@ format_docstrings = true
separate_kwargs_with_semicolon = true
always_for_in = true
annotate_untyped_fields_with_any = false
join_lines_based_on_source = false
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,6 @@ LocalPreferences.toml
data/

benchmarks/results

# Generated by tutorials
pinn_nested_ad.gif
3 changes: 2 additions & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ pages = [
"Intermediate" => [
"tutorials/intermediate/1_NeuralODE.md",
"tutorials/intermediate/2_BayesianNN.md",
"tutorials/intermediate/3_HyperNet.md"
"tutorials/intermediate/3_HyperNet.md",
"tutorials/intermediate/4_PINN2DPDE.md"
],
"Advanced" => [
"tutorials/advanced/1_GravitationalWaveForm.md"
Expand Down
3 changes: 2 additions & 1 deletion docs/src/.vitepress/config.mts
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ export default defineConfig({
text: 'Intermediate', collapsed: false, items: [
{ text: 'MNIST Classification using Neural ODEs', link: '/tutorials/intermediate/1_NeuralODE' },
{ text: 'Bayesian Neural Network', link: '/tutorials/intermediate/2_BayesianNN' },
{ text: 'Training a HyperNetwork on MNIST and FashionMNIST', link: '/tutorials/intermediate/3_HyperNet' }]
{ text: 'Training a HyperNetwork on MNIST and FashionMNIST', link: '/tutorials/intermediate/3_HyperNet' },
{ text: 'Training a PINN on 2D PDE', link: '/tutorials/intermediate/4_PINN2DPDE' }]
},
{
text: 'Advanced', collapsed: false, items: [
Expand Down
Binary file added docs/src/public/pinn_nested_ad.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 7 additions & 1 deletion docs/src/tutorials/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,13 @@ const intermediate = [
href: "intermediate/3_HyperNet",
src: "../hypernet.jpg",
caption: "Training a HyperNetwork",
desc: "Train a hypernetwork to work on multiple datasets by predicting neural network parameters."
desc: "Train a hypernetwork to work on multiple datasets by predicting NN parameters."
},
{
href: "intermediate/4_PINN2DPDE",
src: "../pinn_nested_ad.gif",
caption: "Training a PINN",
desc: "Train a PINN to solve 2D PDEs (using Nested AD)."
}
];
Expand Down
1 change: 1 addition & 0 deletions docs/tutorials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ const INTERMEDIATE_TUTORIALS = [
"NeuralODE/main.jl" => "CUDA",
"BayesianNN/main.jl" => "CPU",
"HyperNet/main.jl" => "CUDA",
"PINN2DPDE/main.jl" => "CUDA",
]
const ADVANCED_TUTORIALS = [
"GravitationalWaveForm/main.jl" => "CPU",
Expand Down
3 changes: 2 additions & 1 deletion examples/HyperNet/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ function load_dataset(::Type{dset}, n_train::Int, n_eval::Int, batchsize::Int) w
imgs, labels = dset(:test)[1:n_eval]
x_test, y_test = reshape(imgs, 28, 28, 1, n_eval), onehotbatch(labels, 0:9)

return (DataLoader((x_train, y_train); batchsize=min(batchsize, n_train), shuffle=true),
return (
DataLoader((x_train, y_train); batchsize=min(batchsize, n_train), shuffle=true),
DataLoader((x_test, y_test); batchsize=min(batchsize, n_eval), shuffle=false))
end

Expand Down
3 changes: 2 additions & 1 deletion examples/NeuralODE/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ function create_model(model_fn=NeuralODE; dev=gpu_device(), use_named_tuple::Boo
## Construct the Neural ODE Model
model = Chain(FlattenLayer(),
Dense(784 => 20, tanh),
model_fn(Chain(Dense(20 => 10, tanh), Dense(10 => 10, tanh), Dense(10 => 20, tanh));
model_fn(
Chain(Dense(20 => 10, tanh), Dense(10 => 10, tanh), Dense(10 => 20, tanh));
save_everystep=false, reltol=1.0f-3,
abstol=1.0f-3, save_start=false, sensealg),
Base.Fix1(diffeqsol_to_array, 20),
Expand Down
14 changes: 14 additions & 0 deletions examples/PINN2DPDE/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
232 changes: 232 additions & 0 deletions examples/PINN2DPDE/main.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
# # Training a PINN on 2D PDE

# In this tutorial we will go over using a PINN to solve 2D PDEs. We will be using the
# system from [NeuralPDE Tutorials](https://docs.sciml.ai/NeuralPDE/stable/tutorials/gpu/).
# However, we will be using our custom loss function and use nested AD capabilities of
# Lux.jl.

# This is a demonstration of Lux.jl. For serious use cases of PINNs, please refer to
# the package: [NeuralPDE.jl](https://github.com/SciML/NeuralPDE.jl).

# ## Package Imports

using ADTypes, Lux, Optimisers, Zygote, Random, Printf, Statistics, MLUtils, OnlineStats,
CairoMakie
using LuxCUDA

CUDA.allowscalar(false)

const gdev = gpu_device()
const cdev = cpu_device()

# ## Problem Definition

# Since Lux supports efficient nested AD up to 2nd order, we will rewrite the problem
# with first order derivatives, so that we can compute the gradients of the loss using
# 2nd order AD.

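# We solve the 2D heat equation
#
# $$\frac{\partial u}{\partial t} = \frac{\partial^2 u}{\partial x^2} + \frac{\partial^2 u}{\partial y^2},$$
#
# whose analytical solution $u(x, y, t) = e^{x + y} \cos(x + y + 4t)$ is used to generate
# the training and boundary data. Introducing the auxiliary variables
# $v = \partial u / \partial x$ and $w = \partial u / \partial y$ gives the first order system
#
# $$\frac{\partial u}{\partial t} = \frac{\partial v}{\partial x} + \frac{\partial w}{\partial y}, \qquad v = \frac{\partial u}{\partial x}, \qquad w = \frac{\partial u}{\partial y},$$
#
# whose residuals are penalized by the physics informed loss.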

# ## Define the Neural Networks

# All the networks take 3 input variables and output a scalar value. Here, we will define
# a wrapper over the 3 networks, so that we can train them using
# [`Training.TrainState`](@ref).

## Container layer that groups the three sub-networks of the PINN under the field names
## `u`, `v` and `w`. The field names are passed to `Lux.AbstractLuxContainerLayer` as the
## type parameter, and the loss function accesses parameters/states as `ps.u`, `st.u`, etc.
struct PINN{TU, TV, TW} <: Lux.AbstractLuxContainerLayer{(:u, :v, :w)}
    u::TU
    v::TV
    w::TW
end

## Build one sub-network: three `Dense` layers of width `hidden_dims` using the activation
## `act`, followed by a linear `Dense` layer with a single output. The input size is 3.
function create_mlp(act, hidden_dims)
    input_layer = Dense(3 => hidden_dims, act)
    hidden_layer_1 = Dense(hidden_dims => hidden_dims, act)
    hidden_layer_2 = Dense(hidden_dims => hidden_dims, act)
    output_layer = Dense(hidden_dims => 1)
    return Chain(input_layer, hidden_layer_1, hidden_layer_2, output_layer)
end

## Keyword constructor: creates three independent `tanh` MLPs (one each for `u`, `v`, `w`),
## all with the same hidden width `hidden_dims` (default 32).
function PINN(; hidden_dims::Int=32)
    u_net, v_net, w_net = ntuple(_ -> create_mlp(tanh, hidden_dims), 3)
    return PINN(u_net, v_net, w_net)
end

# ## Define the Loss Functions

# We will define a custom loss function to compute the loss using 2nd order AD. We
# will use the following loss function

## Physics-informed residual loss for the first-order system.
##
## `u`, `v`, `w` are the three networks wrapped as `StatefulLuxLayer`s and `xyt` is the
## batch of inputs. The slicing below treats `xyt` (and hence each gradient) as a matrix
## whose rows 1, 2 and 3 correspond to `x`, `y` and `t`.
##
## Returns the sum of three mean-squared residuals:
##   * ∂u/∂t - ∂v/∂x - ∂w/∂y
##   * v - ∂u/∂x
##   * w - ∂u/∂y
@views function physics_informed_loss_function(
        u::StatefulLuxLayer, v::StatefulLuxLayer, w::StatefulLuxLayer, xyt::AbstractArray)
    ## Gradient of the summed network output w.r.t. the inputs. The function handed to
    ## `Zygote.gradient` is the composition `sum ∘ net`.
    ∂u_∂xyt = only(Zygote.gradient(sum ∘ u, xyt))
    ∂u_∂x, ∂u_∂y, ∂u_∂t = ∂u_∂xyt[1:1, :], ∂u_∂xyt[2:2, :], ∂u_∂xyt[3:3, :]
    ∂v_∂x = only(Zygote.gradient(sum ∘ v, xyt))[1:1, :]
    v_xyt = v(xyt)
    ∂w_∂y = only(Zygote.gradient(sum ∘ w, xyt))[2:2, :]
    w_xyt = w(xyt)
    return (
        mean(abs2, ∂u_∂t .- ∂v_∂x .- ∂w_∂y) +
        mean(abs2, v_xyt .- ∂u_∂x) +
        mean(abs2, w_xyt .- ∂u_∂y)
    )
end

# Additionally, we need to compute the loss wrt the boundary conditions.

## Mean-squared error between the network prediction `u(xyt)` and `target`.
function mse_loss_function(u::StatefulLuxLayer, target::AbstractArray, xyt::AbstractArray)
    prediction = u(xyt)
    return MSELoss()(prediction, target)
end

## Total training objective in the `(model, ps, st, data)` form that is handed to
## `Training.single_train_step!`.
##
## `data` is the tuple `(xyt, target_data, xyt_bc, target_bc)`: interior inputs and their
## targets, followed by boundary inputs and their targets.
##
## Returns `(loss, states, stats)` where `loss` is the sum of the physics, data and boundary
## terms, `states` holds the states of the three sub-networks, and `stats` exposes the
## individual terms as `physics_loss`, `data_loss` and `bc_loss`.
function loss_function(model, ps, st, (xyt, target_data, xyt_bc, target_bc))
    ## Wrap every sub-network together with its own parameters and states.
    u_net = StatefulLuxLayer{true}(model.u, ps.u, st.u)
    v_net = StatefulLuxLayer{true}(model.v, ps.v, st.v)
    w_net = StatefulLuxLayer{true}(model.w, ps.w, st.w)

    pde_term = physics_informed_loss_function(u_net, v_net, w_net, xyt)
    interior_term = mse_loss_function(u_net, target_data, xyt)
    boundary_term = mse_loss_function(u_net, target_bc, xyt_bc)

    updated_states = (; u=u_net.st, v=v_net.st, w=w_net.st)
    loss_terms = (;
        physics_loss=pde_term, data_loss=interior_term, bc_loss=boundary_term)
    return (pde_term + interior_term + boundary_term, updated_states, loss_terms)
end

# ## Generate the Data

# We will generate some random data to train the model on. We will take data on a square
# spatial and temporal domain $x \in [0, 2]$, $y \in [0, 2]$, and $t \in [0, 2]$. Typically,
# you want to be smarter about the sampling process, but for the sake of simplicity, we will
# skip that.

## Closed-form reference solution exp(x + y) * cos(x + y + 4t), applied elementwise.
function analytical_solution(x, y, t)
    return exp.(x .+ y) .* cos.(x .+ y .+ 4 .* t)
end

## Convenience method for inputs whose rows 1, 2 and 3 hold `x`, `y` and `t`.
function analytical_solution(xyt)
    x, y, t = xyt[1, :], xyt[2, :], xyt[3, :]
    return analytical_solution(x, y, t)
end

begin
    ## Interior points: every combination of `grid_len` equally spaced values in [0, 2]
    ## for x, y and t, stored as a 3 × N matrix with rows (x, y, t).
    grid_len = 16

    grid = range(0.0f0, 2.0f0; length=grid_len)
    xyt = stack(collect.(vec(collect(Iterators.product(grid, grid, grid)))))

    target_data = reshape(analytical_solution(xyt), 1, :)

    ## Boundary / initial-condition points: `bc_len` samples for each of the five slices
    ## t = 0, x = 0, x = 2, y = 0 and y = 2.
    bc_len = 512

    x = collect(range(0.0f0, 2.0f0; length=bc_len))
    y = collect(range(0.0f0, 2.0f0; length=bc_len))
    t = collect(range(0.0f0, 2.0f0; length=bc_len))

    xyt_bc = hcat(
        stack((x, y, fill(0.0f0, bc_len)); dims=1),
        stack((fill(0.0f0, bc_len), y, t); dims=1),
        stack((fill(2.0f0, bc_len), y, t); dims=1),
        stack((x, fill(0.0f0, bc_len), t); dims=1),
        stack((x, fill(2.0f0, bc_len), t); dims=1)
    )
    target_bc = reshape(analytical_solution(xyt_bc), 1, :)

    ## Common value range of interior and boundary targets, used to scale both target sets
    ## (and later to undo the scaling of the predictions).
    min_target_bc, max_target_bc = extrema(target_bc)
    min_data, max_data = extrema(target_data)
    min_pde_val = min(min_data, min_target_bc)
    max_pde_val = max(max_data, max_target_bc)

    ## Min-max normalize the inputs (each array by its own extrema) and the targets (by the
    ## shared range computed above).
    xyt = let (lo, hi) = extrema(xyt)
        (xyt .- lo) ./ (hi - lo)
    end
    xyt_bc = let (lo, hi) = extrema(xyt_bc)
        (xyt_bc .- lo) ./ (hi - lo)
    end
    target_bc = (target_bc .- min_pde_val) ./ (max_pde_val - min_pde_val)
    target_data = (target_data .- min_pde_val) ./ (max_pde_val - min_pde_val)
end
nothing #hide

# ## Training

## Train a `PINN` on the interior data `(xyt, target_data)` and the boundary data
## `(xyt_bc, target_bc)`.
##
## Keyword arguments:
##   * `seed`: seed applied to the default RNG before the parameters are initialized.
##   * `maxiters`: number of optimization steps to run.
##   * `hidden_dims`: hidden width of each of the three sub-networks.
##
## Returns the trained model wrapped in a `StatefulLuxLayer`, with parameters and states
## moved back through `cdev`. Throws an `ArgumentError` if the loss becomes NaN.
function train_model(xyt, target_data, xyt_bc, target_bc; seed::Int=0,
        maxiters::Int=50000, hidden_dims::Int=32)
    rng = Random.default_rng()
    Random.seed!(rng, seed)

    pinn = PINN(; hidden_dims)
    ps, st = Lux.setup(rng, pinn) |> gdev

    bc_dataloader = DataLoader((xyt_bc, target_bc); batchsize=32, shuffle=true) |> gdev
    pde_dataloader = DataLoader((xyt, target_data); batchsize=32, shuffle=true) |> gdev

    train_state = Training.TrainState(pinn, ps, st, Adam(0.05f0))
    ## Piecewise-constant learning-rate schedule, dropped by 10x at iterations 5000 and 10000.
    lr = i -> i < 5000 ? 0.05f0 : (i < 10000 ? 0.005f0 : 0.0005f0)

    ## One `Lag(Float32, 32)` tracker per reported quantity; the mean of each tracker's
    ## stored values is printed in parentheses next to the current value.
    total_loss_tracker, physics_loss_tracker, data_loss_tracker, bc_loss_tracker = ntuple(
        _ -> Lag(Float32, 32), 4)

    iter = 1
    ## Both loaders are cycled indefinitely; the loop is terminated by the iteration count.
    for ((xyt_batch, target_data_batch), (xyt_bc_batch, target_bc_batch)) in zip(
        Iterators.cycle(pde_dataloader), Iterators.cycle(bc_dataloader))
        Optimisers.adjust!(train_state, lr(iter))

        _, loss, stats, train_state = Training.single_train_step!(
            AutoZygote(), loss_function,
            (xyt_batch, target_data_batch, xyt_bc_batch, target_bc_batch),
            train_state)

        fit!(total_loss_tracker, loss)
        fit!(physics_loss_tracker, stats.physics_loss)
        fit!(data_loss_tracker, stats.data_loss)
        fit!(bc_loss_tracker, stats.bc_loss)

        mean_loss = mean(OnlineStats.value(total_loss_tracker))
        mean_physics_loss = mean(OnlineStats.value(physics_loss_tracker))
        mean_data_loss = mean(OnlineStats.value(data_loss_tracker))
        mean_bc_loss = mean(OnlineStats.value(bc_loss_tracker))

        isnan(loss) && throw(ArgumentError("NaN Loss Detected"))

        if iter % 500 == 1 || iter == maxiters
            @printf "Iteration: [%5d / %5d] \t Loss: %.9f (%.9f) \t Physics Loss: %.9f \
                     (%.9f) \t Data Loss: %.9f (%.9f) \t BC \
                     Loss: %.9f (%.9f)\n" iter maxiters loss mean_loss stats.physics_loss mean_physics_loss stats.data_loss mean_data_loss stats.bc_loss mean_bc_loss
        end

        iter += 1
        ## Stop only after step `maxiters` has been executed, so that exactly `maxiters`
        ## updates run and the `iter == maxiters` report above is reachable.
        iter > maxiters && break
    end

    return StatefulLuxLayer{true}(
        pinn, cdev(train_state.parameters), cdev(train_state.states))
end

## Train with the default settings of `train_model` on the normalized data.
trained_model = train_model(xyt, target_data, xyt_bc, target_bc)
## Only the `u` sub-network is needed for prediction: re-wrap it with its own parameters
## and states, and put it in test mode.
trained_u = Lux.testmode(StatefulLuxLayer{true}(
trained_model.model.u, trained_model.ps.u, trained_model.st.u))
nothing #hide

# ## Visualizing the Results
## Evaluation grid: all (x, y, t) combinations, stored as a 3 × N matrix.
ts = 0.0f0:0.05f0:2.0f0
xs = 0.0f0:0.02f0:2.0f0
ys = 0.0f0:0.02f0:2.0f0
grid = stack(collect.(vec(collect(Iterators.product(xs, ys, ts)))))

u_real = reshape(analytical_solution(grid), length.((xs, ys, ts))...)

## The network is evaluated on min-max normalized inputs; its output is then mapped back
## to the original value range using `min_pde_val` and `max_pde_val`.
grid_normalized = let (lo, hi) = extrema(grid)
    (grid .- lo) ./ (hi - lo)
end
u_pred = reshape(trained_u(grid_normalized), length.((xs, ys, ts))...)
u_pred = u_pred .* (max_pde_val - min_pde_val) .+ min_pde_val

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")
    ## Absolute prediction error for every time slice; the colorbar limits span all slices.
    errs = map(i -> abs.(u_pred[:, :, i] .- u_real[:, :, i]), eachindex(ts))
    Colorbar(fig[1, 2]; limits=extrema(stack(errs)))

    ## Record one frame per time slice into the GIF.
    CairoMakie.record(fig, "pinn_nested_ad.gif", eachindex(ts); framerate=10) do frame
        ax.title = "Abs. Predictor Error | Time: $(ts[frame])"
        contour!(ax, xs, ys, errs[frame]; levels=10, linewidth=2)
        heatmap!(ax, xs, ys, errs[frame])
        return fig
    end

    fig
end
nothing #hide

# ![](pinn_nested_ad.gif)
3 changes: 2 additions & 1 deletion src/helpers/compact.jl
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,8 @@ function CompactLuxLayer(dispatch::StaticSymbol, f::F, name::NAME_TYPE,
NamedTuple((name => CompactMacroImpl.kwarg_descriptor(val),)))
end
end
return CompactLuxLayer(dispatch, f, name, str, setup_strings, NamedTuple((; layers...)),
return CompactLuxLayer(
dispatch, f, name, str, setup_strings, NamedTuple((; layers...)),
CompactMacroImpl.ValueStorage(; others...), splatted_kwargs)
end

Expand Down
3 changes: 2 additions & 1 deletion test/utils_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,8 @@ end
x = randn(3, 2)
@test Lux.recursive_eltype(x, Val(true)) == Float64

x_wrapped = (ForwardDiff.Dual.(x), ForwardDiff.Dual(2.0), ReverseDiff.track.(x),
x_wrapped = (
ForwardDiff.Dual.(x), ForwardDiff.Dual(2.0), ReverseDiff.track.(x),
ReverseDiff.track(2.0), ReverseDiff.track(x),
Tracker.param.(x), Tracker.param(x), Tracker.param(2.0))

Expand Down

1 comment on commit 3ca41c8

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lux Benchmarks

Benchmark suite Current: 3ca41c8 Previous: e7caa61 Ratio
Dense(512 => 512, identity)(512 x 128)/forward/CPU/2 thread(s) 412583 ns 412125 ns 1.00
Dense(512 => 512, identity)(512 x 128)/forward/CPU/4 thread(s) 322979.5 ns 243958 ns 1.32
Dense(512 => 512, identity)(512 x 128)/forward/CPU/8 thread(s) 243687.5 ns 323708 ns 0.75
Dense(512 => 512, identity)(512 x 128)/forward/CPU/1 thread(s) 738583 ns 739541 ns 1.00
Dense(512 => 512, identity)(512 x 128)/forward/GPU/CUDA 43341 ns 43923 ns 0.99
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/2 thread(s) 1331541.5 ns 1371937.5 ns 0.97
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/4 thread(s) 2407792 ns 1261354.5 ns 1.91
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/8 thread(s) 16439208.5 ns 14008417 ns 1.17
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/1 thread(s) 2194563 ns 2194708 ns 1.00
Dense(512 => 512, identity)(512 x 128)/zygote/GPU/CUDA 205021 ns 205355.5 ns 1.00
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/2 thread(s) 1426479.5 ns 1417791 ns 1.01
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/4 thread(s) 895625 ns 887292 ns 1.01
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/8 thread(s) 1543917 ns 1812208 ns 0.85
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/1 thread(s) 2206208.5 ns 2212208 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/2 thread(s) 1777646 ns 1733292 ns 1.03
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/4 thread(s) 1078167 ns 1102709 ns 0.98
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/8 thread(s) 1530958 ns 1521208.5 ns 1.01
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/1 thread(s) 3007208 ns 3013042 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/GPU/CUDA 207880 ns 206261 ns 1.01
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/2 thread(s) 12178459 ns 12142666.5 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/4 thread(s) 8815750 ns 8846250 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/8 thread(s) 9208125 ns 9207354 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/1 thread(s) 18565750 ns 18580625 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/GPU/CUDA 1492124 ns 1487526 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/2 thread(s) 17280708 ns 17293999.5 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/4 thread(s) 13973458 ns 13993542 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/8 thread(s) 14487354.5 ns 14483333.5 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/1 thread(s) 21838146 ns 21828604 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/2 thread(s) 249950041 ns 250797875.5 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/4 thread(s) 148180333 ns 148823333 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/8 thread(s) 116724083 ns 115883354 ns 1.01
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/1 thread(s) 447231500 ns 454090334 ns 0.98
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/GPU/CUDA 5449958 ns 5479922 ns 0.99
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/2 thread(s) 1223396291 ns 1224986000 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/4 thread(s) 929732416 ns 928765000 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/8 thread(s) 832528750 ns 827976542 ns 1.01
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/1 thread(s) 1633536917 ns 1644101834 ns 0.99
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/GPU/CUDA 31232077 ns 31171453 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/2 thread(s) 1127648666 ns 1135178542 ns 0.99
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/4 thread(s) 1002243833.5 ns 999714166.5 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/8 thread(s) 1330111333.5 ns 1303828854 ns 1.02
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/1 thread(s) 1732141146 ns 1741741896 ns 0.99
lenet(28, 28, 1, 32)/forward/CPU/2 thread(s) 1037166 ns 1090792 ns 0.95
lenet(28, 28, 1, 32)/forward/CPU/4 thread(s) 1626521 ns 1637666 ns 0.99
lenet(28, 28, 1, 32)/forward/CPU/8 thread(s) 3777708 ns 3582708 ns 1.05
lenet(28, 28, 1, 32)/forward/CPU/1 thread(s) 781583 ns 786167 ns 0.99
lenet(28, 28, 1, 32)/forward/GPU/CUDA 262013.5 ns 262029.5 ns 1.00
lenet(28, 28, 1, 32)/zygote/CPU/2 thread(s) 3044041.5 ns 2971792 ns 1.02
lenet(28, 28, 1, 32)/zygote/CPU/4 thread(s) 4097042 ns 4123750 ns 0.99
lenet(28, 28, 1, 32)/zygote/CPU/8 thread(s) 11116208 ns 10426437.5 ns 1.07
lenet(28, 28, 1, 32)/zygote/CPU/1 thread(s) 3145292 ns 3148229 ns 1.00
lenet(28, 28, 1, 32)/zygote/GPU/CUDA 1092823.5 ns 1092866.5 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/2 thread(s) 2334666.5 ns 2326979.5 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/4 thread(s) 1419459 ns 1322083 ns 1.07
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/8 thread(s) 1571666 ns 1673459 ns 0.94
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/1 thread(s) 4190125 ns 4208000 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/GPU/CUDA 208094.5 ns 208009 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/2 thread(s) 19396000 ns 19416729 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/4 thread(s) 16089229 ns 16098667 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/8 thread(s) 17212895.5 ns 17355625.5 ns 0.99
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/1 thread(s) 25854646 ns 25850083 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/GPU/CUDA 1594895 ns 1585986 ns 1.01
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/2 thread(s) 34213875 ns 34150479.5 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/4 thread(s) 31031917 ns 30929208 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/8 thread(s) 31076583 ns 30921458 ns 1.01
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/1 thread(s) 36708292 ns 36974958 ns 0.99
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/2 thread(s) 4528125 ns 4531521 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/4 thread(s) 2779625 ns 2560770.5 ns 1.09
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/8 thread(s) 2669750 ns 2921667 ns 0.91
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/1 thread(s) 8386375 ns 8371042 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/GPU/CUDA 420922 ns 418970 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/2 thread(s) 38785750 ns 38979334 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/4 thread(s) 32129375 ns 32095416.5 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/8 thread(s) 32277375 ns 32177750 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/1 thread(s) 51825125 ns 52201375 ns 0.99
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/GPU/CUDA 2627381 ns 2623471.5 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/2 thread(s) 88427125 ns 88784146 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/4 thread(s) 113463375 ns 115080917 ns 0.99
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/8 thread(s) 227295542 ns 219472208 ns 1.04
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/1 thread(s) 74279083 ns 74369625 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/2 thread(s) 268165959 ns 268757584 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/4 thread(s) 158761895.5 ns 156384958 ns 1.02
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/8 thread(s) 123751562.5 ns 126657687.5 ns 0.98
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/1 thread(s) 487404084 ns 493203959 ns 0.99
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/GPU/CUDA 6976350 ns 7020761 ns 0.99
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/2 thread(s) 1472349770.5 ns 1470831958.5 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/4 thread(s) 1171105895.5 ns 1166912584 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/8 thread(s) 1066855854.5 ns 1069943646 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/1 thread(s) 2006883229.5 ns 2001019750 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/GPU/CUDA 34648809 ns 34698558 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/2 thread(s) 1717718833 ns 1717061375 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/4 thread(s) 1535202521 ns 1532673563 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/8 thread(s) 1878248417 ns 1763256208 ns 1.07
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/1 thread(s) 2205912250 ns 2235971083 ns 0.99
lenet(28, 28, 1, 128)/forward/CPU/2 thread(s) 2014792 ns 2033917 ns 0.99
lenet(28, 28, 1, 128)/forward/CPU/4 thread(s) 3001875 ns 3001750 ns 1.00
lenet(28, 28, 1, 128)/forward/CPU/8 thread(s) 8115875 ns 6995583 ns 1.16
lenet(28, 28, 1, 128)/forward/CPU/1 thread(s) 2432666 ns 2410250.5 ns 1.01
lenet(28, 28, 1, 128)/forward/GPU/CUDA 266572 ns 277586.5 ns 0.96
lenet(28, 28, 1, 128)/zygote/CPU/2 thread(s) 9287083.5 ns 9596208.5 ns 0.97
lenet(28, 28, 1, 128)/zygote/CPU/4 thread(s) 12062125 ns 11959208.5 ns 1.01
lenet(28, 28, 1, 128)/zygote/CPU/8 thread(s) 25629896 ns 24601937.5 ns 1.04
lenet(28, 28, 1, 128)/zygote/CPU/1 thread(s) 11743625.5 ns 11737000 ns 1.00
lenet(28, 28, 1, 128)/zygote/GPU/CUDA 1194959.5 ns 1205760 ns 0.99
vgg16(32, 32, 3, 32)/forward/CPU/2 thread(s) 385030062.5 ns 380715583 ns 1.01
vgg16(32, 32, 3, 32)/forward/CPU/4 thread(s) 288360083 ns 309134958 ns 0.93
vgg16(32, 32, 3, 32)/forward/CPU/8 thread(s) 255891729 ns 240484604 ns 1.06
vgg16(32, 32, 3, 32)/forward/CPU/1 thread(s) 454057438 ns 455377687.5 ns 1.00
vgg16(32, 32, 3, 32)/forward/GPU/CUDA 4834603 ns 4901230.5 ns 0.99
vgg16(32, 32, 3, 32)/zygote/CPU/2 thread(s) 1159548125 ns 1156536042 ns 1.00
vgg16(32, 32, 3, 32)/zygote/CPU/4 thread(s) 928639750 ns 934444375 ns 0.99
vgg16(32, 32, 3, 32)/zygote/CPU/8 thread(s) 1041025834 ns 937955958 ns 1.11
vgg16(32, 32, 3, 32)/zygote/CPU/1 thread(s) 1400301917 ns 1404766250 ns 1.00
vgg16(32, 32, 3, 32)/zygote/GPU/CUDA 17860514 ns 19140053 ns 0.93
lenet(28, 28, 1, 64)/forward/CPU/2 thread(s) 1063417 ns 1044417 ns 1.02
lenet(28, 28, 1, 64)/forward/CPU/4 thread(s) 2031167 ns 1665375 ns 1.22
lenet(28, 28, 1, 64)/forward/CPU/8 thread(s) 6215979 ns 4659125 ns 1.33
lenet(28, 28, 1, 64)/forward/CPU/1 thread(s) 1289687.5 ns 1381458 ns 0.93
lenet(28, 28, 1, 64)/forward/GPU/CUDA 274190 ns 277017 ns 0.99
lenet(28, 28, 1, 64)/zygote/CPU/2 thread(s) 6281917 ns 6471291 ns 0.97
lenet(28, 28, 1, 64)/zygote/CPU/4 thread(s) 12390375 ns 13171521 ns 0.94
lenet(28, 28, 1, 64)/zygote/CPU/8 thread(s) 21407437 ns 19155167 ns 1.12
lenet(28, 28, 1, 64)/zygote/CPU/1 thread(s) 6091125 ns 6084416 ns 1.00
lenet(28, 28, 1, 64)/zygote/GPU/CUDA 1242182 ns 1234554.5 ns 1.01
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/2 thread(s) 70428792 ns 70412479 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/4 thread(s) 43611000 ns 43727396 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/8 thread(s) 39667125 ns 39654166 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/1 thread(s) 132442542 ns 132669229.5 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/GPU/CUDA 1932803.5 ns 1946871.5 ns 0.99
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/2 thread(s) 355601563 ns 356558937 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/4 thread(s) 270115458 ns 269901750 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/8 thread(s) 253762937.5 ns 253901458 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/1 thread(s) 535226770.5 ns 543773875 ns 0.98
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/GPU/CUDA 12278246 ns 12296125.5 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/2 thread(s) 399133667 ns 394712500 ns 1.01
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/4 thread(s) 394300875 ns 374705250 ns 1.05
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/8 thread(s) 679147666.5 ns 695586875 ns 0.98
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/1 thread(s) 710284375 ns 709322375 ns 1.00
vgg16(32, 32, 3, 128)/forward/CPU/2 thread(s) 1194640667 ns 1186149833 ns 1.01
vgg16(32, 32, 3, 128)/forward/CPU/4 thread(s) 689604666 ns 826656438 ns 0.83
vgg16(32, 32, 3, 128)/forward/CPU/8 thread(s) 645000312 ns 633369958 ns 1.02
vgg16(32, 32, 3, 128)/forward/CPU/1 thread(s) 1774100021 ns 1780054896 ns 1.00
vgg16(32, 32, 3, 128)/forward/GPU/CUDA 12543961 ns 12309879.5 ns 1.02
vgg16(32, 32, 3, 128)/zygote/CPU/2 thread(s) 3679223834 ns 3667291125 ns 1.00
vgg16(32, 32, 3, 128)/zygote/CPU/4 thread(s) 2826435875 ns 2822189750 ns 1.00
vgg16(32, 32, 3, 128)/zygote/CPU/8 thread(s) 2707509167 ns 2708758541 ns 1.00
vgg16(32, 32, 3, 128)/zygote/CPU/1 thread(s) 5055415292 ns 5073574916 ns 1.00
vgg16(32, 32, 3, 128)/zygote/GPU/CUDA 49466162 ns 49860253.5 ns 0.99
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/2 thread(s) 3397917 ns 3405875.5 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/4 thread(s) 2077646 ns 2065250 ns 1.01
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/8 thread(s) 2508042 ns 2540250 ns 0.99
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/1 thread(s) 6018583.5 ns 6014604 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/GPU/CUDA 330593.5 ns 343548 ns 0.96
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/2 thread(s) 25938625.5 ns 25973124.5 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/4 thread(s) 18943166.5 ns 18895979.5 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/8 thread(s) 19554000 ns 19531416.5 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/1 thread(s) 39254458 ns 39260291 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/GPU/CUDA 2474870 ns 2465534 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/2 thread(s) 55544542 ns 55438083 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/4 thread(s) 82396771 ns 82776500 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/8 thread(s) 173728042 ns 170298791.5 ns 1.02
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/1 thread(s) 45516333 ns 45669667 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/2 thread(s) 1779250 ns 1778750 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/4 thread(s) 1102875 ns 1090500 ns 1.01
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/8 thread(s) 1584416 ns 1558104.5 ns 1.02
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/1 thread(s) 3020250 ns 3024604.5 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/GPU/CUDA 214596.5 ns 213362 ns 1.01
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/2 thread(s) 12529625 ns 12540583.5 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/4 thread(s) 9208687.5 ns 9227000 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/8 thread(s) 9658437 ns 9657125 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/1 thread(s) 18982292 ns 19007583 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/GPU/CUDA 1539988 ns 1541996 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/2 thread(s) 17638625 ns 17663416.5 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/4 thread(s) 14351250 ns 14318000 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/8 thread(s) 14600292 ns 14637687.5 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/1 thread(s) 22151625 ns 22165500 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/2 thread(s) 70460291 ns 70531042 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/4 thread(s) 43523750 ns 43664104 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/8 thread(s) 39539833 ns 39661792 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/1 thread(s) 132477750 ns 132673333 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/GPU/CUDA 1881252.5 ns 1958870.5 ns 0.96
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/2 thread(s) 359873333.5 ns 360097833 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/4 thread(s) 345894479 ns 347045875 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/8 thread(s) 305415375 ns 303749416 ns 1.01
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/1 thread(s) 725056958 ns 730460916 ns 0.99
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/GPU/CUDA 13380753 ns 13376571.5 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/2 thread(s) 421101688 ns 417572000 ns 1.01
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/4 thread(s) 421140000 ns 427267750 ns 0.99
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/8 thread(s) 748051292 ns 741225500 ns 1.01
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/1 thread(s) 715116625 ns 714873709 ns 1.00
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/2 thread(s) 1695875 ns 1698792 ns 1.00
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/4 thread(s) 1355249.5 ns 1054000 ns 1.29
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/8 thread(s) 1159396 ns 1350854 ns 0.86
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/1 thread(s) 2409792 ns 2441916 ns 0.99
mlp7layer_bn(gelu)(32 x 256)/forward/GPU/CUDA 588933.5 ns 591065.5 ns 1.00
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/2 thread(s) 8963979.5 ns 8953166 ns 1.00
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/4 thread(s) 13000041 ns 13693270.5 ns 0.95
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/8 thread(s) 33199749.5 ns 30593292 ns 1.09
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/1 thread(s) 9835041 ns 9833291 ns 1.00
mlp7layer_bn(gelu)(32 x 256)/zygote/GPU/CUDA 1481667.5 ns 1490752.5 ns 0.99
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/2 thread(s) 17747854.5 ns 18112458 ns 0.98
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/4 thread(s) 17252604.5 ns 17467500.5 ns 0.99
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/8 thread(s) 31103916.5 ns 29723417 ns 1.05
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/1 thread(s) 14355042 ns 14333291 ns 1.00
Dense(512 => 512, relu)(512 x 128)/forward/CPU/2 thread(s) 669500.5 ns 678667 ns 0.99
Dense(512 => 512, relu)(512 x 128)/forward/CPU/4 thread(s) 556667 ns 520708 ns 1.07
Dense(512 => 512, relu)(512 x 128)/forward/CPU/8 thread(s) 1058916.5 ns 1030125 ns 1.03
Dense(512 => 512, relu)(512 x 128)/forward/CPU/1 thread(s) 725958 ns 724041.5 ns 1.00
Dense(512 => 512, relu)(512 x 128)/forward/GPU/CUDA 48545 ns 48348.5 ns 1.00
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/2 thread(s) 1480334 ns 1565250 ns 0.95
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/4 thread(s) 1045958 ns 1021042 ns 1.02
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/8 thread(s) 1649187 ns 1403604 ns 1.17
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/1 thread(s) 2243312 ns 2276771 ns 0.99
Dense(512 => 512, relu)(512 x 128)/zygote/GPU/CUDA 241138.5 ns 240894 ns 1.00
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/2 thread(s) 1523125 ns 1558958.5 ns 0.98
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/4 thread(s) 1079125 ns 1066687 ns 1.01
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/8 thread(s) 1484062.5 ns 1686104.5 ns 0.88
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/1 thread(s) 2259292 ns 2226166 ns 1.01
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/2 thread(s) 3401000.5 ns 3395333 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/4 thread(s) 2066583.5 ns 2056708 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/8 thread(s) 2508167 ns 2516479.5 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/1 thread(s) 6010687.5 ns 6008875 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/GPU/CUDA 286726 ns 288149 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/2 thread(s) 24062625 ns 24072041.5 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/4 thread(s) 17164833 ns 17187916 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/8 thread(s) 17141292 ns 17151041.5 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/1 thread(s) 37492334 ns 37508458.5 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/GPU/CUDA 2402302 ns 2405382 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/2 thread(s) 53536916 ns 53630708.5 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/4 thread(s) 83532937.5 ns 85883417 ns 0.97
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/8 thread(s) 171297229.5 ns 169400792 ns 1.01
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/1 thread(s) 44566417 ns 44653333.5 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/2 thread(s) 249899167 ns 250291979 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/4 thread(s) 147932291 ns 148449041 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/8 thread(s) 116470479.5 ns 115844229 ns 1.01
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/1 thread(s) 450194500 ns 454457542 ns 0.99
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/GPU/CUDA 5441310 ns 5457486 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/2 thread(s) 1101316834 ns 1104315584 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/4 thread(s) 855731854 ns 856185292 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/8 thread(s) 827979937.5 ns 826815646 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/1 thread(s) 1754256625 ns 1779319041 ns 0.99
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/GPU/CUDA 28790457 ns 28848435 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/2 thread(s) 1014882938 ns 1008544916.5 ns 1.01
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/4 thread(s) 964872792 ns 978883958 ns 0.99
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/8 thread(s) 1306006041 ns 1273700875 ns 1.03
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/1 thread(s) 1738222270.5 ns 1738185250 ns 1.00
mlp7layer_bn(relu)(32 x 256)/forward/CPU/2 thread(s) 1311167 ns 1308271 ns 1.00
mlp7layer_bn(relu)(32 x 256)/forward/CPU/4 thread(s) 957750 ns 736416 ns 1.30
mlp7layer_bn(relu)(32 x 256)/forward/CPU/8 thread(s) 703375 ns 970250.5 ns 0.72
mlp7layer_bn(relu)(32 x 256)/forward/CPU/1 thread(s) 1939625 ns 1992812.5 ns 0.97
mlp7layer_bn(relu)(32 x 256)/forward/GPU/CUDA 571080 ns 573191 ns 1.00
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/2 thread(s) 6006770.5 ns 5738375 ns 1.05
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/4 thread(s) 6329521 ns 9161583 ns 0.69
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/8 thread(s) 25490292 ns 23815937.5 ns 1.07
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/1 thread(s) 7082792 ns 7075209 ns 1.00
mlp7layer_bn(relu)(32 x 256)/zygote/GPU/CUDA 1359778.5 ns 1420790.5 ns 0.96
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/2 thread(s) 11590167 ns 10487208 ns 1.11
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/4 thread(s) 10239645.5 ns 10551083 ns 0.97
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/8 thread(s) 18128834 ns 17215062.5 ns 1.05
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/1 thread(s) 8225354.5 ns 8751250 ns 0.94
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/2 thread(s) 362875 ns 407270.5 ns 0.89
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/4 thread(s) 363208 ns 377167 ns 0.96
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/8 thread(s) 3032958 ns 1964791 ns 1.54
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/1 thread(s) 87687.5 ns 88084 ns 1.00
Dense(128 => 128, gelu)(128 x 128)/forward/GPU/CUDA 28003 ns 28331 ns 0.99
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/2 thread(s) 388708 ns 349125.5 ns 1.11
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/4 thread(s) 440375 ns 443959 ns 0.99
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/8 thread(s) 4703354 ns 4468750 ns 1.05
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/1 thread(s) 259209 ns 258666 ns 1.00
Dense(128 => 128, gelu)(128 x 128)/zygote/GPU/CUDA 220116.5 ns 226875.5 ns 0.97
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/2 thread(s) 419708.5 ns 379791 ns 1.11
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/4 thread(s) 470625 ns 474250 ns 0.99
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/8 thread(s) 4962583 ns 4250750 ns 1.17
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/1 thread(s) 270875 ns 270875 ns 1
Dense(128 => 128, relu)(128 x 128)/forward/CPU/2 thread(s) 307959 ns 352916 ns 0.87
Dense(128 => 128, relu)(128 x 128)/forward/CPU/4 thread(s) 298292 ns 315333 ns 0.95
Dense(128 => 128, relu)(128 x 128)/forward/CPU/8 thread(s) 764833 ns 733437.5 ns 1.04
Dense(128 => 128, relu)(128 x 128)/forward/CPU/1 thread(s) 54917 ns 54271 ns 1.01
Dense(128 => 128, relu)(128 x 128)/forward/GPU/CUDA 27854 ns 28442.5 ns 0.98
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/2 thread(s) 352750 ns 296333.5 ns 1.19
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/4 thread(s) 336000 ns 341000 ns 0.99
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/8 thread(s) 887229.5 ns 547833 ns 1.62
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/1 thread(s) 151687.5 ns 151666 ns 1.00
Dense(128 => 128, relu)(128 x 128)/zygote/GPU/CUDA 205379 ns 211615 ns 0.97
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/2 thread(s) 366770.5 ns 310375 ns 1.18
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/4 thread(s) 350334 ns 354833 ns 0.99
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/8 thread(s) 470875 ns 397666.5 ns 1.18
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/1 thread(s) 151188 ns 150875 ns 1.00
vgg16(32, 32, 3, 64)/forward/CPU/2 thread(s) 606032542 ns 602928541 ns 1.01
vgg16(32, 32, 3, 64)/forward/CPU/4 thread(s) 429552104 ns 428260667 ns 1.00
vgg16(32, 32, 3, 64)/forward/CPU/8 thread(s) 384048291 ns 376195062.5 ns 1.02
vgg16(32, 32, 3, 64)/forward/CPU/1 thread(s) 874614500 ns 872683792 ns 1.00
vgg16(32, 32, 3, 64)/forward/GPU/CUDA 7024831 ns 7030918 ns 1.00
vgg16(32, 32, 3, 64)/zygote/CPU/2 thread(s) 2012811729 ns 2004923333 ns 1.00
vgg16(32, 32, 3, 64)/zygote/CPU/4 thread(s) 1612557354 ns 1605834520.5 ns 1.00
vgg16(32, 32, 3, 64)/zygote/CPU/8 thread(s) 1572362875 ns 1586581187 ns 0.99
vgg16(32, 32, 3, 64)/zygote/CPU/1 thread(s) 2635509917 ns 2623747583 ns 1.00
vgg16(32, 32, 3, 64)/zygote/GPU/CUDA 25977885 ns 25764798 ns 1.01
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/2 thread(s) 521667 ns 524375 ns 0.99
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/4 thread(s) 439708 ns 398542 ns 1.10
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/8 thread(s) 2731999.5 ns 2053208 ns 1.33
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/1 thread(s) 865959 ns 865312.5 ns 1.00
Dense(512 => 512, gelu)(512 x 128)/forward/GPU/CUDA 47570 ns 47173 ns 1.01
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/2 thread(s) 1889812.5 ns 1905417 ns 0.99
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/4 thread(s) 2797000 ns 1742250 ns 1.61
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/8 thread(s) 16236125 ns 15020687.5 ns 1.08
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/1 thread(s) 2650416 ns 2728374.5 ns 0.97
Dense(512 => 512, gelu)(512 x 128)/zygote/GPU/CUDA 249198.5 ns 248474 ns 1.00
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/2 thread(s) 1925354.5 ns 1969542 ns 0.98
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/4 thread(s) 5070458 ns 1845500 ns 2.75
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/8 thread(s) 16364250 ns 14844708.5 ns 1.10
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/1 thread(s) 2752417 ns 2764666.5 ns 1.00
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/2 thread(s) 1507000 ns 1573416 ns 0.96
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/4 thread(s) 1228583 ns 947833.5 ns 1.30
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/8 thread(s) 1068166.5 ns 1240792 ns 0.86
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/1 thread(s) 2208000 ns 2252813 ns 0.98
mlp7layer_bn(tanh)(32 x 256)/forward/GPU/CUDA 589459 ns 585933 ns 1.01
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/2 thread(s) 5951958 ns 5942334 ns 1.00
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/4 thread(s) 4655854.5 ns 8620916 ns 0.54
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/8 thread(s) 27127042 ns 25125208 ns 1.08
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/1 thread(s) 6596583 ns 7308250 ns 0.90
mlp7layer_bn(tanh)(32 x 256)/zygote/GPU/CUDA 1347970 ns 1387329 ns 0.97
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/2 thread(s) 12790458 ns 13102000 ns 0.98
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/4 thread(s) 12034667 ns 12175645.5 ns 0.99
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/8 thread(s) 22409521.5 ns 20369042 ns 1.10
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/1 thread(s) 10615250 ns 10679959 ns 0.99
Dense(16 => 16, relu)(16 x 128)/forward/CPU/2 thread(s) 2292 ns 2500 ns 0.92
Dense(16 => 16, relu)(16 x 128)/forward/CPU/4 thread(s) 2500 ns 4917 ns 0.51
Dense(16 => 16, relu)(16 x 128)/forward/CPU/8 thread(s) 3125 ns 2875 ns 1.09
Dense(16 => 16, relu)(16 x 128)/forward/CPU/1 thread(s) 2292 ns 2458 ns 0.93
Dense(16 => 16, relu)(16 x 128)/forward/GPU/CUDA 24734 ns 24933 ns 0.99
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/2 thread(s) 7416 ns 7375 ns 1.01
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/4 thread(s) 7416.5 ns 7042 ns 1.05
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/8 thread(s) 7666 ns 7333.5 ns 1.05
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/1 thread(s) 7167 ns 7125 ns 1.01
Dense(16 => 16, relu)(16 x 128)/zygote/GPU/CUDA 210121.5 ns 213416.5 ns 0.98
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/2 thread(s) 8250 ns 8209 ns 1.00
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/4 thread(s) 8250 ns 8209 ns 1.00
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/8 thread(s) 8417 ns 8500 ns 0.99
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/1 thread(s) 5958 ns 5916 ns 1.01
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/2 thread(s) 10500 ns 10625 ns 0.99
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/4 thread(s) 12896 ns 16167 ns 0.80
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/8 thread(s) 10666 ns 10875 ns 0.98
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/1 thread(s) 7083 ns 7187.5 ns 0.99
Dense(16 => 16, gelu)(16 x 128)/forward/GPU/CUDA 24767 ns 25092 ns 0.99
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/2 thread(s) 20208 ns 19916 ns 1.01
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/4 thread(s) 20041 ns 19875 ns 1.01
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/8 thread(s) 20209 ns 20333 ns 0.99
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/1 thread(s) 20125 ns 19854.5 ns 1.01
Dense(16 => 16, gelu)(16 x 128)/zygote/GPU/CUDA 230521 ns 233641 ns 0.99
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/2 thread(s) 23625 ns 23562.5 ns 1.00
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/4 thread(s) 23500 ns 23416 ns 1.00
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/8 thread(s) 23750 ns 23750 ns 1
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/1 thread(s) 21125 ns 21208 ns 1.00
Dense(128 => 128, identity)(128 x 128)/forward/CPU/2 thread(s) 28292 ns 28958 ns 0.98
Dense(128 => 128, identity)(128 x 128)/forward/CPU/4 thread(s) 28500 ns 28416 ns 1.00
Dense(128 => 128, identity)(128 x 128)/forward/CPU/8 thread(s) 28708 ns 28500 ns 1.01
Dense(128 => 128, identity)(128 x 128)/forward/CPU/1 thread(s) 47396 ns 46041 ns 1.03
Dense(128 => 128, identity)(128 x 128)/forward/GPU/CUDA 25741 ns 26269 ns 0.98
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/2 thread(s) 220542 ns 226146 ns 0.98
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/4 thread(s) 281604.5 ns 274895.5 ns 1.02
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/8 thread(s) 4282416 ns 4238916.5 ns 1.01
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/1 thread(s) 146208.5 ns 145125 ns 1.01
Dense(128 => 128, identity)(128 x 128)/zygote/GPU/CUDA 210161.5 ns 211030 ns 1.00
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/2 thread(s) 330833 ns 337437.5 ns 0.98
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/4 thread(s) 321584 ns 316687.5 ns 1.02
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/8 thread(s) 763437.5 ns 800125 ns 0.95
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/1 thread(s) 161895.5 ns 161750 ns 1.00
Dense(16 => 16, identity)(16 x 128)/forward/CPU/2 thread(s) 2000 ns 1917 ns 1.04
Dense(16 => 16, identity)(16 x 128)/forward/CPU/4 thread(s) 1875 ns 2250 ns 0.83
Dense(16 => 16, identity)(16 x 128)/forward/CPU/8 thread(s) 2541 ns 2167 ns 1.17
Dense(16 => 16, identity)(16 x 128)/forward/CPU/1 thread(s) 1666 ns 1875 ns 0.89
Dense(16 => 16, identity)(16 x 128)/forward/GPU/CUDA 23005 ns 23176 ns 0.99
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/2 thread(s) 5250 ns 5333 ns 0.98
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/4 thread(s) 5458 ns 5000 ns 1.09
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/8 thread(s) 5541 ns 5417 ns 1.02
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/1 thread(s) 5500 ns 5167 ns 1.06
Dense(16 => 16, identity)(16 x 128)/zygote/GPU/CUDA 256932 ns 256165.5 ns 1.00
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/2 thread(s) 11250 ns 11250 ns 1
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/4 thread(s) 11250 ns 11417 ns 0.99
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/8 thread(s) 11458 ns 11375 ns 1.01
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/1 thread(s) 7000 ns 6958 ns 1.01
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/2 thread(s) 79995437.5 ns 79988042 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/4 thread(s) 49006333 ns 47941125 ns 1.02
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/8 thread(s) 43270416.5 ns 45005708 ns 0.96
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/1 thread(s) 151307500 ns 151719833 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/GPU/CUDA 2675260 ns 2721409 ns 0.98
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/2 thread(s) 673208958 ns 663734583 ns 1.01
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/4 thread(s) 413629250 ns 409152750 ns 1.01
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/8 thread(s) 396318917 ns 398007625 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/1 thread(s) 684603375 ns 690336167 ns 0.99
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/GPU/CUDA 14606131 ns 14673781 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/2 thread(s) 713812625 ns 709715271 ns 1.01
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/4 thread(s) 676256333 ns 686379083 ns 0.99
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/8 thread(s) 1061049792 ns 980829291 ns 1.08
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/1 thread(s) 1000454667 ns 997377042 ns 1.00

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.