Commit: CPUs go burrr....
avik-pal committed Apr 1, 2022
1 parent 05a3427 commit bf42a79
Showing 9 changed files with 163 additions and 131 deletions.
2 changes: 1 addition & 1 deletion LICENSE
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2022 Avik Pal <avikpal@iitk.ac.in> and contributors
Copyright (c) 2022 Avik Pal <avikpal@mit.edu> and contributors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
3 changes: 3 additions & 0 deletions Project.toml
@@ -6,6 +6,7 @@ version = "0.1.1"
[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
@@ -19,9 +20,11 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[compat]
CUDA = "3"
ChainRulesCore = "1"
FillArrays = "0.13"
Flux = "0.12"
NNlib = "0.8"
NNlibCUDA = "0.2"
Octavian = "0.3"
Setfield = "0.8"
julia = "1"

40 changes: 4 additions & 36 deletions README.md
@@ -134,46 +134,14 @@ ExplicitFluxLayers.apply(model, x, ps, st)
gradient(p -> sum(ExplicitFluxLayers.apply(model, x, p, st)[1]), ps)
```

### Using Other AD Libraries -- Yota

NOTE: This is somewhat experimental

```julia
using ExplicitFluxLayers, Flux, Optimisers, Random, Yota, Zygote

model = ExplicitFluxLayers.Chain(
# ExplicitFluxLayers.BatchNorm(128),
ExplicitFluxLayers.Dense(128, 256, tanh),
# ExplicitFluxLayers.BatchNorm(256),
ExplicitFluxLayers.Chain(
ExplicitFluxLayers.Dense(256, 1, tanh),
ExplicitFluxLayers.Dense(1, 10)
)
)

ps, st = ExplicitFluxLayers.setup(MersenneTwister(0), model)

x = rand(Float32, 128, 1024)

# using Zygote
@btime Zygote.gradient(p -> sum(ExplicitFluxLayers.apply(model, x, p, st)[1]), ps)[1] # 7.858 ms (2002 allocations: 26.95 MiB)

# using Yota
@btime Yota.grad(p -> sum(ExplicitFluxLayers.apply(model, x, p, st)[1]), ps)[2][2] #
```

## Implemented Layers

These layers have the same API as their Flux counterparts.

* `Chain`
* `Dense`
* `Conv`
* `BatchNorm`
* `WeightNorm`
* `Parallel`
* `SkipConnection`
* `MaxPool`, `MeanPool`
* `Chain`, `Parallel`, `SkipConnection`
* `Dense`, `Diagonal`
* `Conv`, `MaxPool`, `MeanPool`
* `BatchNorm`, `WeightNorm`
* `ReshapeLayer`, `SelectDim`, `FlattenLayer`, `NoOpLayer`, `WrappedFunction`
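
A minimal usage sketch combining a few of these layers, following the same
`setup`/`apply` pattern shown above (illustrative only; constructor arguments
mirror the Flux counterparts):

```julia
using ExplicitFluxLayers, Random, Zygote

model = ExplicitFluxLayers.Chain(
    ExplicitFluxLayers.Dense(128, 256, tanh),
    ExplicitFluxLayers.BatchNorm(256),
    ExplicitFluxLayers.Dense(256, 10),
)

# Parameters and states live outside the layers.
ps, st = ExplicitFluxLayers.setup(MersenneTwister(0), model)

x = rand(Float32, 128, 16)
y, st = ExplicitFluxLayers.apply(model, x, ps, st)

# Gradients w.r.t. the explicit parameters.
gs = Zygote.gradient(p -> sum(ExplicitFluxLayers.apply(model, x, p, st)[1]), ps)[1]
```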

## TODOs
2 changes: 1 addition & 1 deletion src/ExplicitFluxLayers.jl
@@ -1,6 +1,6 @@
module ExplicitFluxLayers

using Statistics, NNlib, CUDA, Random, Setfield, ChainRulesCore, Octavian, LinearAlgebra
using Statistics, NNlib, CUDA, Random, Setfield, ChainRulesCore, Octavian, LinearAlgebra, FillArrays
import NNlibCUDA: batchnorm
import Flux
import Flux:
18 changes: 17 additions & 1 deletion src/autodiff.jl
@@ -22,7 +22,23 @@ function ChainRulesCore.rrule(
Z = X * Y
function sparse_matmul_pullback(Δ)
Δ = unthunk(Δ)
return NoTangent(), _project(Δ * Y', X), X.mat' * Δ
return NoTangent(), _project(Δ * Y', X), X.mat' * Δ
end
return Z, sparse_matmul_pullback
end

# Fast Matmul
function ChainRulesCore.rrule(
::typeof(fast_matmul!), C::AbstractVecOrMat{T}, A::AbstractMatrix{T}, B::AbstractVecOrMat{T}
) where {T}
fast_matmul!(C, A, B)
function fast_matmul!_pullback(Δ)
Δ = unthunk(Δ)
return NoTangent(), Δ, fast_matmul(Δ, B'), fast_matmul(A', Δ)
end
function fast_matmul!_pullback(Δ::FillArrays.Fill)
Δ = Array(unthunk(Δ))
return NoTangent(), Δ, fast_matmul(Δ, B'), fast_matmul(A', Δ)
end
return C, fast_matmul!_pullback
end
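
Aside (not part of this diff): the new rrule follows the standard pullback
pattern for an in-place matrix multiply. A self-contained sketch of that
pattern, using hypothetical stand-ins (`mymul!`/`mymul`) rather than the
package-internal `fast_matmul!`/`fast_matmul`:

```julia
using ChainRulesCore, LinearAlgebra

mymul!(C, A, B) = mul!(C, A, B)   # stand-in for fast_matmul!
mymul(A, B) = A * B               # stand-in for fast_matmul

function ChainRulesCore.rrule(::typeof(mymul!), C, A, B)
    mymul!(C, A, B)
    function mymul!_pullback(Δ)
        Δ = unthunk(Δ)
        # C is simply overwritten, so its cotangent passes through unchanged;
        # dA = Δ * B' and dB = A' * Δ, exactly as for out-of-place matmul.
        return NoTangent(), Δ, mymul(Δ, B'), mymul(A', Δ)
    end
    return C, mymul!_pullback
end

A, B = rand(3, 4), rand(4, 2)
C = zeros(3, 2)
Z, back = ChainRulesCore.rrule(mymul!, C, A, B)
_, dC, dA, dB = back(ones(3, 2))   # exercise the pullback directly
```

The extra `fast_matmul!_pullback(Δ::FillArrays.Fill)` method above only
materializes a lazy `Fill` cotangent (e.g. the one produced when
differentiating `sum`) into a dense `Array` before doing the same computation.
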
50 changes: 24 additions & 26 deletions src/layers/basic.jl
@@ -220,30 +220,26 @@ function Dense(in_dims::Int, out_dims::Int, λ=identity; initW=glorot_uniform, i
return Dense{bias,typeof(λ),typeof(initW),typeof(initb)}(λ, in_dims, out_dims, initW, initb)
end

function initialparameters(rng::AbstractRNG, d::Dense{true})
return (weight=d.initW(rng, d.out_dims, d.in_dims), bias=d.initb(rng, d.out_dims, 1))
function initialparameters(rng::AbstractRNG, d::Dense{bias}) where {bias}
if bias
return (weight=d.initW(rng, d.out_dims, d.in_dims), bias=d.initb(rng, d.out_dims, 1))
else
return (weight=d.initW(rng, d.out_dims, d.in_dims),)
end
end
initialparameters(rng::AbstractRNG, d::Dense{false}) = (weight=d.initW(rng, d.out_dims, d.in_dims),)

parameterlength(d::Dense{true}) = d.out_dims * (d.in_dims + 1)
parameterlength(d::Dense{false}) = d.out_dims * d.in_dims
parameterlength(d::Dense{bias}) where {bias} = bias ? d.out_dims * (d.in_dims + 1) : d.out_dims * d.in_dims
statelength(d::Dense) = 0

Base.@pure function (d::Dense)(x::AbstractArray, ps::NamedTuple, st::NamedTuple)
y, st = d(reshape(x, size(x, 1), :), ps, st)
return reshape(y, :, size(x)[2:end]...), st
end

Base.@pure function (d::Dense{false})(x::AbstractVecOrMat, ps::NamedTuple, st::NamedTuple)
return d.λ.(fast_matmul(ps.weight, x)), st
end

Base.@pure function (d::Dense{true})(x::AbstractMatrix, ps::NamedTuple, st::NamedTuple)
return d.λ.(fast_matmul(ps.weight, x) .+ ps.bias), st
end

Base.@pure function (d::Dense{true})(x::AbstractVector, ps::NamedTuple, st::NamedTuple)
return d.λ.(fast_matmul(ps.weight, x) .+ ps.bias[:]), st
Base.@pure function (d::Dense{bias})(
x::AbstractArray{T,N}, ps::NamedTuple, st::NamedTuple
) where {bias,T,N}
if bias
b = N == 1 ? ps.bias[:] : ps.bias
return d.λ.(fast_matmul(ps.weight, x) .+ b), st
else
return d.λ.(fast_matmul(ps.weight, x)), st
end
end
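
Aside (not part of this diff): the separate `Dense{true}`/`Dense{false}`
methods are folded into one above. Because `bias` is a type parameter, the
`if bias` branch is resolved at compile time for each concrete
`Dense{true}`/`Dense{false}`, so the merge costs nothing at runtime. A toy
sketch of the pattern with a hypothetical `ToyDense` (not the package type):

```julia
struct ToyDense{bias,W,B}
    weight::W
    b::B
end
ToyDense(W, b) = ToyDense{true,typeof(W),typeof(b)}(W, b)
ToyDense(W) = ToyDense{false,typeof(W),Nothing}(W, nothing)

# `bias` is known from the type, so this branch is compiled away.
(d::ToyDense{bias})(x) where {bias} = bias ? (d.weight * x .+ d.b) : d.weight * x

W, b, x = rand(Float32, 4, 3), rand(Float32, 4), rand(Float32, 3)
ToyDense(W, b)(x)   # with bias
ToyDense(W)(x)      # without bias
```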

## Diagonal
@@ -274,10 +270,12 @@ parameterlength(d::Diagonal{true}) = 2 * d.dims
parameterlength(d::Diagonal{false}) = d.dims
statelength(d::Diagonal) = 0

Base.@pure function (d::Diagonal{false})(x::AbstractVecOrMat, ps::NamedTuple, st::NamedTuple)
return d.λ.(ps.weight .* x), st
end

Base.@pure function (d::Diagonal{true})(x::AbstractVecOrMat, ps::NamedTuple, st::NamedTuple)
return d.λ.(ps.weight .* x .+ ps.bias), st
Base.@pure function (d::Diagonal{bias})(
x::AbstractVecOrMat, ps::NamedTuple, st::NamedTuple
) where {bias}
if bias
return d.λ.(ps.weight .* x .+ ps.bias), st
else
return d.λ.(ps.weight .* x), st
end
end
88 changes: 57 additions & 31 deletions src/layers/conv.jl
@@ -1,4 +1,4 @@
struct Conv{N,M,F1,F2} <: AbstractExplicitLayer
struct Conv{N,bias,cdims,M,F1,F2} <: AbstractExplicitLayer
λ::F1
in_chs::Int
out_chs::Int
@@ -7,15 +7,15 @@ struct Conv{N,M,F1,F2} <: AbstractExplicitLayer
pad::NTuple{M,Int}
dilation::NTuple{N,Int}
groups::Int
bias::Bool
initW::F2
end

function Conv(
k::NTuple{N,Integer},
ch::Pair{<:Integer,<:Integer},
λ=identity;
init=glorot_uniform,
input_size::Union{Nothing,NTuple{N,Integer}}=nothing,
initW=glorot_uniform,
stride=1,
pad=0,
dilation=1,
@@ -25,91 +25,117 @@ function Conv(
stride = expand(Val(N), stride)
dilation = expand(Val(N), dilation)
pad = calc_padding(Conv, pad, k, dilation, stride)
return Conv(λ, first(ch), last(ch), k, stride, pad, dilation, groups, bias, init)
λ = NNlib.fast_act(λ)
cdims = if input_size === nothing
nothing
else
DenseConvDims(
(input_size..., first(ch), 1),
(k..., ch...);
stride=stride,
padding=pad,
dilation=dilation,
groups=groups,
)
end
return Conv{N,bias,cdims,length(pad),typeof(λ),typeof(initW)}(
λ, first(ch), last(ch), k, stride, pad, dilation, groups, initW
)
end

function initialparameters(rng::AbstractRNG, c::Conv{N}) where {N}
function initialparameters(rng::AbstractRNG, c::Conv{N,bias}) where {N,bias}
initW(args...) = c.initW(rng, args...)
weight = convfilter(c.kernel_size, c.in_chs => c.out_chs; init=initW, groups=c.groups)
return (c.bias ? (weight=weight, bias=zeros(eltype(weight), ntuple(_ -> 1, N)..., c.out_chs, 1)) : (weight = weight,))
return (bias ? (weight=weight, bias=zeros(eltype(weight), ntuple(_ -> 1, N)..., c.out_chs, 1)) : (weight=weight,))
end

parameterlength(c::Conv) = prod(c.kernel_size) * c.in_chs * c.out_chs + (c.bias ? c.out_chs : 0)
parameterlength(c::Conv{N,bias}) where {N,bias} = prod(c.kernel_size) * c.in_chs * c.out_chs + (bias ? c.out_chs : 0)

Base.@pure function (c::Conv)(x::AbstractArray, ps::NamedTuple, st::NamedTuple)
λ = NNlib.fast_act(c.λ, x)
cdims = DenseConvDims(x, ps.weight; stride = c.stride, padding = c.pad, dilation = c.dilation, groups = c.groups)
if c.bias
return λ.(conv(x, ps.weight, cdims) .+ ps.bias), st
Base.@pure function (c::Conv{N,bias,C})(x::AbstractArray, ps::NamedTuple, st::NamedTuple) where {N,bias,C}
cdims = if C === nothing
DenseConvDims(x, ps.weight; stride=c.stride, padding=c.pad, dilation=c.dilation, groups=c.groups)
else
C
end
if bias
return c.λ.(conv(x, ps.weight, cdims) .+ ps.bias), st
# FIXME: Needs https://github.com/FluxML/NNlibCUDA.jl/pull/45 to be merged
# return conv_bias_act(x, ps.weight, cdims, ps.bias, λ), st
else
return λ.(conv(x, ps.weight, cdims)), st
return c.λ.(conv(x, ps.weight, cdims)), st
end
end
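
Aside (not part of this diff): the new `input_size` keyword lets the
`DenseConvDims` be computed once at construction time instead of on every
forward pass. A hedged usage sketch (constructor arguments mirror Flux;
`setup`/`apply` as in the README):

```julia
using ExplicitFluxLayers, Random

# input_size is the spatial size only; the channel and batch dimensions are
# appended internally as in the constructor above (batch size of 1).
c = ExplicitFluxLayers.Conv((3, 3), 3 => 16, tanh; input_size=(32, 32))
ps, st = ExplicitFluxLayers.setup(MersenneTwister(0), c)

x = rand(Float32, 32, 32, 3, 1)   # W × H × C × N, matching input_size
y, st = ExplicitFluxLayers.apply(c, x, ps, st)
```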

function Base.show(io::IO, l::Conv)
print(io, "Conv(", l.kernel_size)
print(io, ", ", l.in_chs, " => ", l.out_chs)
_print_conv_opt(io, l)
print(io, ")")
return print(io, ")")
end

function _print_conv_opt(io::IO, l)
function _print_conv_opt(io::IO, l::Conv{N,bias}) where {N,bias}
l.λ == identity || print(io, ", ", l.λ)
all(==(0), l.pad) || print(io, ", pad=", _maybetuple_string(l.pad))
all(==(1), l.stride) || print(io, ", stride=", _maybetuple_string(l.stride))
all(==(1), l.dilation) || print(io, ", dilation=", _maybetuple_string(l.dilation))
(l.groups == 1) || print(io, ", groups=", l.groups)
(l.bias == false) && print(io, ", bias=false")
return (bias == false) && print(io, ", bias=false")
end

struct MaxPool{N,M} <: AbstractExplicitLayer
struct MaxPool{N,M,pdims} <: AbstractExplicitLayer
k::NTuple{N,Int}
pad::NTuple{M,Int}
stride::NTuple{N,Int}
end

function MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N
function MaxPool(k::NTuple{N,Integer}; pad=0, stride=k, input_size::Union{Nothing,NTuple{N,Integer}}=nothing) where {N}
stride = expand(Val(N), stride)
pad = calc_padding(MaxPool, pad, k, 1, stride)
return MaxPool(k, pad, stride)
pdims = if input_size === nothing
nothing
else
PoolDims((input_size..., 1, 1), k; stride=stride, padding=pad)  # channel/batch dims assumed to be 1
end
return MaxPool{N,length(pad),pdims}(k, pad, stride)
end

Base.@pure function (m::MaxPool)(x, ::NamedTuple, st::NamedTuple)
pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride)
Base.@pure function (m::MaxPool{N,M,P})(x, ::NamedTuple, st::NamedTuple) where {N,M,P}
pdims = P === nothing ? PoolDims(x, m.k; padding=m.pad, stride=m.stride) : P
return maxpool(x, pdims), st
end

function Base.show(io::IO, m::MaxPool)
print(io, "MaxPool(", m.k)
all(==(0), m.pad) || print(io, ", pad=", _maybetuple_string(m.pad))
m.stride == m.k || print(io, ", stride=", _maybetuple_string(m.stride))
print(io, ")")
return print(io, ")")
end


struct MeanPool{N,M} <: AbstractExplicitLayer
struct MeanPool{N,M,pdims} <: AbstractExplicitLayer
k::NTuple{N,Int}
pad::NTuple{M,Int}
stride::NTuple{N,Int}
end

function MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N
function MeanPool(k::NTuple{N,Integer}; pad=0, stride=k, input_size::Union{Nothing,NTuple{N,Integer}}=nothing) where {N}
stride = expand(Val(N), stride)
pad = calc_padding(MeanPool, pad, k, 1, stride)
return MeanPool(k, pad, stride)
pdims = if input_size === nothing
nothing
else
PoolDims((input_size..., 1, 1), k; stride=stride, padding=pad)  # channel/batch dims assumed to be 1
end
return MeanPool{N,length(pad),pdims}(k, pad, stride)
end

Base.@pure function (m::MeanPool)(x, ::NamedTuple, st::NamedTuple)
pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride)
Base.@pure function (m::MeanPool{N,M,P})(x, ::NamedTuple, st::NamedTuple) where {N,M,P}
pdims = P === nothing ? PoolDims(x, m.k; padding=m.pad, stride=m.stride) : P
return meanpool(x, pdims), st
end

function Base.show(io::IO, m::MeanPool)
print(io, "MeanPool(", m.k)
all(==(0), m.pad) || print(io, ", pad=", _maybetuple_string(m.pad))
m.stride == m.k || print(io, ", stride=", _maybetuple_string(m.stride))
print(io, ")")
return print(io, ")")
end
