Skip to content

Commit

Permalink
fix: force_inline inside generated functions to avoid recursion issues
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 17, 2024
1 parent 70484be commit 8722ba1
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 10 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.0.3"
version = "1.0.4"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
18 changes: 9 additions & 9 deletions src/layers/containers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,14 @@ function initialstates(rng::AbstractRNG, l::SkipConnection{T, <:AbstractLuxLayer
end

function (skip::SkipConnection)(x, ps, st::NamedTuple)
mx, st = apply(skip.layers, x, ps, st)
mx, st = @inline apply(skip.layers, x, ps, st)
return skip.connection(mx, x), st
end

function (skip::SkipConnection{<:AbstractLuxLayer, <:AbstractLuxLayer})(
x, ps, st::NamedTuple)
mx, st1 = apply(skip.layers, x, ps.layers, st.layers)
y, st2 = apply(skip.connection, (mx, x), ps.connection, st.connection)
mx, st1 = @inline apply(skip.layers, x, ps.layers, st.layers)
y, st2 = @inline apply(skip.connection, (mx, x), ps.connection, st.connection)
return y, (layers=st1, connection=st2)
end

Expand Down Expand Up @@ -180,7 +180,7 @@ end
getinput(i) = T <: Tuple ? :(x[$i]) : :x
calls = []
append!(calls,
[:(($(y_symbols[i]), $(st_symbols[i])) = LuxCore.apply(
[:(($(y_symbols[i]), $(st_symbols[i])) = @inline apply(
layers.$(names[i]), $(getinput(i)), ps.$(names[i]), st.$(names[i])))
for i in 1:N])
push!(calls, :(st = NamedTuple{$names}((($(Tuple(st_symbols)...),)))))
Expand Down Expand Up @@ -273,7 +273,7 @@ BranchLayer(; name::NAME_TYPE=nothing, kwargs...) = BranchLayer((; kwargs...), n
st_symbols = [gensym() for _ in 1:N]
calls = []
append!(calls,
[:(($(y_symbols[i]), $(st_symbols[i])) = apply(
[:(($(y_symbols[i]), $(st_symbols[i])) = @inline apply(
layers.$(names[i]), x, ps.$(names[i]), st.$(names[i]))) for i in 1:N])
push!(calls, :(st = NamedTuple{$names}((($(Tuple(st_symbols)...),)))))
push!(calls, :(return tuple($(Tuple(y_symbols)...)), st))
Expand Down Expand Up @@ -377,7 +377,7 @@ end
getinput(i) = T <: Tuple ? :(x[$i]) : :x
calls = [:($(y_symbols[N + 1]) = $(getinput(1)))]
append!(calls,
[:(($(y_symbols[i]), $(st_symbols[i])) = apply(
[:(($(y_symbols[i]), $(st_symbols[i])) = @inline apply(
layers.$(names[i]), $(y_symbols[N + 1]), ps.$(names[i]), st.$(names[i]));
$(y_symbols[N + 1]) = connection($(y_symbols[i]), $(getinput(i + 1))))
for i in 1:N])
Expand Down Expand Up @@ -484,7 +484,7 @@ wrap_functions_in_chain_call(x) = x
N = length(fields)
x_symbols = vcat([:x], [gensym() for _ in 1:N])
st_symbols = [gensym() for _ in 1:N]
calls = [:(($(x_symbols[i + 1]), $(st_symbols[i])) = apply(
calls = [:(($(x_symbols[i + 1]), $(st_symbols[i])) = @inline apply(
layers.$(fields[i]), $(x_symbols[i]), ps.$(fields[i]), st.$(fields[i])))
for i in 1:N]
push!(calls, :(st = NamedTuple{$fields}((($(Tuple(st_symbols)...),)))))
Expand Down Expand Up @@ -570,7 +570,7 @@ Maxout(f::Function, n_alts::Int) = Maxout(ntuple(Returns(f()), n_alts)...)
N = length(fields)
y_symbols = [gensym() for _ in 1:N]
st_symbols = [gensym() for _ in 1:N]
calls = [:(($(y_symbols[i]), $(st_symbols[i])) = apply(
calls = [:(($(y_symbols[i]), $(st_symbols[i])) = @inline apply(
layers.$(fields[i]), x, ps.$(fields[i]), st.$(fields[i]))) for i in 1:N]
push!(calls, :(st = NamedTuple{$fields}((($(Tuple(st_symbols)...),)))))
push!(calls, :(res = max.($(Tuple(y_symbols)...))))
Expand Down Expand Up @@ -661,7 +661,7 @@ end
known(IJ) && push!(calls, :($(xs[1]) = x))
for i in 1:known(N)
push!(calls,
:(($(xs[i + known(IJ)]), $(sts[i])) = apply(
:(($(xs[i + known(IJ)]), $(sts[i])) = @inline apply(
model, $(known(IJ) ? :(($(xs[i]), x)) : :x),
ps, $(i == 1 ? :st : sts[i - 1]))))
end
Expand Down
31 changes: 31 additions & 0 deletions test/layers/containers_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -455,3 +455,34 @@ end
end
end
end

@testitem "Coupling Type Instability: #416" setup=[SharedTestSetup] tags=[:core_layers] begin
using ComponentArrays, Random

rng = Random.default_rng()

froggie = Chain(
BranchLayer(NoOpLayer(), NoOpLayer()),
Parallel(
nothing,
Parallel(
+,
Dense(1 => 1),
NoOpLayer()
),
WrappedFunction(first)
)
)

@testset "$(mode)" for (mode, aType, dev, ongpu) in MODES
x = [1.0] |> aType
ps, st = Lux.setup(rng, froggie)
st = st |> dev

ps_nt = ps |> dev
@test @inferred(froggie(x, ps_nt, st)) isa Any

ps_ca = ps |> ComponentArray |> dev
@test @inferred(froggie(x, ps_ca, st)) isa Any
end
end

0 comments on commit 8722ba1

Please sign in to comment.