
LuxDL/Lux.jl



Elegant & Performant Scientific Machine Learning in Julia

A Pure Julia Deep Learning Framework designed for Scientific Machine Learning

💻 Installation

import Pkg
Pkg.add("Lux")

Tip

If you are using a pre-v1 version of Lux.jl, please see the Updating to v1 section for instructions on how to update.
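If you are not sure which version is installed, a quick check from the Julia REPL is sketched below. It assumes Julia 1.9 or newer, where `pkgversion` is available; `import Pkg; Pkg.status("Lux")` reports the same information.

```julia
using Lux

# pkgversion (Julia 1.9+) returns the version of a loaded package
lux_version = pkgversion(Lux)
println("Lux.jl version: ", lux_version)

# Any version below v1 needs the migration steps from the "Updating to v1" section
if lux_version < v"1"
    @warn "Lux.jl $lux_version predates v1; see the Updating to v1 section of the docs."
end
```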

Packages in the Lux.jl ecosystem:
📦 Lux.jl
└ 📦 LuxLib.jl
└ 📦 LuxCore.jl
└ 📦 MLDataDevices.jl
└ 📦 WeightInitializers.jl
└ 📦 LuxTestUtils.jl
└ 📦 LuxCUDA.jl

🀸 Quickstart

using Lux, Random, Optimisers, Zygote
# using LuxCUDA, AMDGPU, Metal, oneAPI # Optional packages for GPU support

# Seeding
rng = Random.default_rng()
Random.seed!(rng, 0)

# Construct the layer
model = Chain(Dense(128, 256, tanh), Chain(Dense(256, 1, tanh), Dense(1, 10)))

# Get the device determined by Lux
dev = gpu_device()

# Parameter and State Variables
ps, st = Lux.setup(rng, model) |> dev

# Dummy Input
x = rand(rng, Float32, 128, 2) |> dev

# Run the model
y, st = Lux.apply(model, x, ps, st)

# Gradients
## First construct a TrainState
train_state = Lux.Training.TrainState(model, ps, st, Adam(0.0001f0))

## We can compute the gradients using Training.compute_gradients
gs, loss, stats, train_state = Lux.Training.compute_gradients(AutoZygote(), MSELoss(),
    (x, dev(rand(rng, Float32, 10, 2))), train_state)

## Optimization
train_state = Training.apply_gradients!(train_state, gs) # or Training.apply_gradients (no `!` at the end)

# Both these steps can be combined into a single call
gs, loss, stats, train_state = Training.single_train_step!(AutoZygote(), MSELoss(),
    (x, dev(rand(rng, Float32, 10, 2))), train_state)
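The quickstart performs a single update. Below is a sketch of a complete (if tiny) training loop built from the same calls. It also shows that the parameters are a plain nested NamedTuple and how to move results back to the CPU with `cpu_device()`. The step count, learning rate, random targets, and the helper name `fit_model` are illustrative assumptions, not recommended settings.

```julia
using Lux, Random, Optimisers, Zygote

rng = Random.default_rng()
Random.seed!(rng, 0)

model = Chain(Dense(128, 256, tanh), Chain(Dense(256, 1, tanh), Dense(1, 10)))

cdev = cpu_device()
dev = gpu_device()  # falls back to the CPU (with a warning) if no GPU backend is functional

ps, st = Lux.setup(rng, model) |> dev

# Parameters live outside the model in a nested NamedTuple; the first
# Dense layer's weight has size (out, in)
size(ps.layer_1.weight)  # (256, 128)

# Illustrative random data: 2 samples with 128 features and 10 targets each
x = rand(rng, Float32, 128, 2) |> dev
y = rand(rng, Float32, 10, 2) |> dev

# Hypothetical helper: repeatedly calls single_train_step! and records the loss
function fit_model(train_state, data; nsteps = 200)
    losses = Float32[]
    for _ in 1:nsteps
        # Returns (gradients, loss, stats, updated train state)
        _, loss, _, train_state = Lux.Training.single_train_step!(
            AutoZygote(), MSELoss(), data, train_state)
        push!(losses, loss)
    end
    return train_state, losses
end

train_state = Lux.Training.TrainState(model, ps, st, Adam(0.001f0))
train_state, losses = fit_model(train_state, (x, y))

# Inference with the trained parameters; testmode switches layers such as
# Dropout or BatchNorm to evaluation behaviour (a no-op for this Dense-only model)
y_pred, _ = model(x, train_state.parameters, Lux.testmode(train_state.states))
y_pred_cpu = y_pred |> cdev  # move the prediction back to host memory
```

The loop is wrapped in a function so that `train_state` stays local. This avoids Julia's soft-scope ambiguity for global variables when the code runs as a script rather than in the REPL.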

🀸 Quickstart with Reactant

using Lux, Random, Optimisers, Reactant, Enzyme

rng = Random.default_rng()
Random.seed!(rng, 0)

model = Chain(Dense(128, 256, tanh), Chain(Dense(256, 1, tanh), Dense(1, 10)))

dev = reactant_device()

ps, st = Lux.setup(rng, model) |> dev

x = rand(rng, Float32, 128, 2) |> dev

# We need to compile the model before we can use it.
model_forward = @compile model(x, ps, Lux.testmode(st))
model_forward(x, ps, Lux.testmode(st))

# Gradients can be computed using Enzyme
@jit Enzyme.gradient(Reverse, sum ∘ first ∘ Lux.apply, Const(model), x, ps, Const(st))

# All of this can be automated using the TrainState API
train_state = Training.TrainState(model, ps, st, Adam(0.001f0))

gs, loss, stats, train_state = Training.single_train_step!(
    AutoEnzyme(), MSELoss(),
    (x, dev(rand(rng, Float32, 10, 2))), train_state
)

📚 Examples

Look in the examples directory for self-contained usage examples. The documentation also has examples, organized by category.

🆘 Getting Help

For usage-related questions, please use GitHub Discussions, which allows questions and answers to be indexed. To report bugs, use GitHub issues or, even better, send in a pull request.

🧑‍🔬 Citation

If you found this library useful in academic work, please cite:

@software{pal2023lux,
  author    = {Pal, Avik},
  title     = {{Lux: Explicit Parameterization of Deep Neural Networks in Julia}},
  month     = apr,
  year      = 2023,
  note      = {If you use this software, please cite it as below.},
  publisher = {Zenodo},
  version   = {v1.4.2},
  doi       = {10.5281/zenodo.7808903},
  url       = {https://doi.org/10.5281/zenodo.7808903},
  swhid     = {swh:1:dir:1a304ec3243961314a1cc7c1481a31c4386c4a34;origin=https://doi.org/10.5281/zenodo.7808903;visit=swh:1:snp:e2bbe43b14bde47c4ddf7e637eb7fc7bd10db8c7;anchor=swh:1:rel:2c0c0ff927e7bfe8fc8bc43fd553ab392a6eb403;path=/}
}

@thesis{pal2023efficient,
  title     = {{On Efficient Training \& Inference of Neural Differential Equations}},
  author    = {Pal, Avik},
  year      = {2023},
  school    = {Massachusetts Institute of Technology}
}

Also consider starring our GitHub repo.

🧑‍💻 Contributing

This section is somewhat incomplete. You can contribute by helping to finish it 😜.

🧪 Testing

The full Lux.jl test suite takes a long time to run. Here's how to test only a portion of the code.

Each @testitem has corresponding tags, for example:

@testitem "SkipConnection" setup=[SharedTestSetup] tags=[:core_layers]

Consider the tests for SkipConnection:

@testitem "SkipConnection" setup=[SharedTestSetup] tags=[:core_layers] begin
    ...
end

We can test the group that SkipConnection belongs to by testing core_layers. To do so, set the LUX_TEST_GROUP environment variable, or rename the tag to narrow the test scope further:

export LUX_TEST_GROUP="core_layers"

Alternatively, directly modify the default test group in runtests.jl:

# const LUX_TEST_GROUP = lowercase(get(ENV, "LUX_TEST_GROUP", "all"))
const LUX_TEST_GROUP = lowercase(get(ENV, "LUX_TEST_GROUP", "core_layers"))

But be sure to restore the default value "all" before submitting the code.
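Restoring the default is easy to forget. One alternative is to set the variable from Julia with `withenv`, which restores the previous value when the block exits. This is sketched below under the assumption that `Pkg.test()` passes the current process environment on to the test subprocess. `current_test_group` is a hypothetical helper that repeats the lookup from runtests.jl, so the effect can be checked without running the suite.

```julia
import Pkg

# Hypothetical helper: the same lookup runtests.jl performs, defaulting to "all"
current_test_group() = lowercase(get(ENV, "LUX_TEST_GROUP", "all"))

# withenv sets LUX_TEST_GROUP only while the block runs, then restores the old value
group = withenv("LUX_TEST_GROUP" => "Core_Layers") do
    # Pkg.test()  # would run only this group (assumes Lux.jl is the active project)
    current_test_group()
end

println(group)  # core_layers
```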

Furthermore, if you want to run a specific test item by name, you can use TestEnv.jl. Start by activating the Lux environment, then run the following:

using TestEnv; TestEnv.activate(); using ReTestItems;

# Assuming you are in the root directory of the Lux repository
ReTestItems.runtests("test/"; name = "NAME OF THE TEST")

For the SkipConnection tests that would be:

ReTestItems.runtests("test/"; name = "SkipConnection")