From 6358ea8fd8c38b1f9b85bc100918b2d2b1e9b398 Mon Sep 17 00:00:00 2001 From: Bernhard Schuster Date: Wed, 30 Sep 2020 11:45:09 +0200 Subject: [PATCH] refactor: use different LayerOps trait requirements per framework --- Cargo.lock | 355 +++++++++++++++++++- coaster/src/plugin.rs | 2 +- juice/src/layer.rs | 35 +- juice/src/layers/common/linear.rs | 16 +- juice/src/layers/container/sequential.rs | 12 +- juice/src/layers/loss/mean_squared_error.rs | 4 +- juice/src/solver/mod.rs | 22 +- juice/src/solvers/mod.rs | 4 +- juice/src/solvers/sgd/mod.rs | 2 +- juice/src/solvers/sgd/momentum.rs | 2 +- juice/src/util.rs | 137 +++++--- 11 files changed, 485 insertions(+), 106 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ad7ea5b8a..6b75c276b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -65,6 +65,12 @@ version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b41b7ea54a0c9d92199de89e20e58d49f02f8e699814ef3fdf266f6f748d15c7" +[[package]] +name = "base64" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3441f0f7b02788e948e47f457ca01f1d7e6d92c693bc132c22b087d3141c03ff" + [[package]] name = "bindgen" version = "0.54.0" @@ -443,12 +449,27 @@ dependencies = [ "strsim 0.9.3", ] +[[package]] +name = "dtoa" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "134951f4028bdadb9b84baf4232681efbf277da25144b9b0ad65df75946c422b" + [[package]] name = "either" version = "1.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bb1f6b1ce1c140482ea30ddd3335fc0024ac7ee112895426e0a629a6c20adfe3" +[[package]] +name = "encoding_rs" +version = "0.8.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a51b8cf747471cb9499b6d59e59b0444f4c90eba8968c4e44874e92b5b64ace2" +dependencies = [ + "cfg-if", +] + [[package]] name = "enum_primitive" version = "0.1.1" @@ -471,6 +492,44 @@ dependencies = [ "termcolor", ] +[[package]] +name = "example-mnist-classification" +version = "0.0.1" +dependencies = [ + "coaster", + "coaster-nn", + "csv", + "docopt", + "env_logger", + "flate2", + "futures", + "futures-util", + "greenglas", + "hyper", + "hyper-rustls", + "juice", + "juice-utils", + "log", + "mnist", + "serde", + "timeit", + "tokio", +] + +[[package]] +name = "example-rnn-regression" +version = "0.0.1" +dependencies = [ + "coaster", + "coaster-nn", + "csv", + "docopt", + "env_logger", + "greenglas", + "juice", + "serde", +] + [[package]] name = "filetime" version = "0.2.10" @@ -501,6 +560,21 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] name = "fuchsia-cprng" version = "0.1.1" @@ -783,6 +857,30 @@ dependencies = [ "webpki", ] +[[package]] +name = "hyper-tls" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d979acc56dcb5b8dddba3917601745e877576475aa046df3226eabdecef78eed" +dependencies = [ + "bytes", + "hyper", + "native-tls", + "tokio", + "tokio-tls", +] + +[[package]] +name = "idna" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02e2673c30ee86b5b96a9cb52ad15718aa1f966f5ab9ad54a8b95d5ca33120a9" +dependencies = [ + "matches", + "unicode-bidi", + "unicode-normalization", +] + [[package]] name = "image" version = "0.23.6" @@ -819,6 +917,12 @@ dependencies = [ "libc", ] +[[package]] +name = "ipnet" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47be2f14c678be2fdcab04ab1171db51b2762ce6f0a8ee87c8dd4a04ed216135" + [[package]] name = "itoa" version = "0.4.6" @@ -862,9 +966,10 @@ dependencies = [ ] [[package]] -name = "juice-examples" -version = "0.1.1" +name = "juice-utils" +version = "0.0.1" dependencies = [ + "bytes", "coaster", "coaster-nn", "csv", @@ -874,11 +979,10 @@ dependencies = [ "futures", "futures-util", "greenglas", - "hyper", - "hyper-rustls", "juice", "log", "mnist", + "reqwest", "serde", "timeit", "tokio", @@ -957,6 +1061,12 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7d947cbb889ed21c2a84be6ffbaebf5b4e0f4340638cba0444907e38b56be084" +[[package]] +name = "matches" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ffc5c5338469d4d3ea17d269fa8ea3512ad247247c30bd2df69e68309ed0a08" + [[package]] name = "maybe-uninit" version = "2.0.0" @@ -978,6 +1088,22 @@ dependencies = [ "autocfg", ] +[[package]] +name = "mime" +version = "0.3.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a60c7ce501c71e03a9c9c0d35b861413ae925bd979cc7a4e30d060069aaac8d" + +[[package]] +name = "mime_guess" +version = "2.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2684d4c2e97d99848d30b324b00c8fcc7e5c897b7cbb5819b09e7c90e8baf212" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "miniz_oxide" version = "0.3.7" @@ -1049,6 +1175,24 @@ version = "0.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2983372caf4480544083767bf2d27defafe32af49ab4df3a0b7fc90793a3664" +[[package]] +name = "native-tls" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b0d88c06fe90d5ee94048ba40409ef1d9315d86f6f38c2efdaad4fb50c58b2d" +dependencies = [ + "lazy_static", + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + [[package]] name = "net2" version = "0.2.34" @@ -1231,12 +1375,39 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b631f7e854af39a1739f401cf34a8a013dfe09eac4fa4dba91e9768bd28168d" +[[package]] +name = "openssl" +version = "0.10.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d575eff3665419f9b83678ff2815858ad9d11567e082f5ac1814baba4e2bcb4" +dependencies = [ + "bitflags", + "cfg-if", + "foreign-types", + "lazy_static", + "libc", + "openssl-sys", +] + [[package]] name = "openssl-probe" version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77af24da69f9d9341038eba93a073b1fdaaa1b788221b00a69bce9e762cb32de" +[[package]] +name = "openssl-sys" +version = "0.9.58" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a842db4709b604f0fe5d1170ae3565899be2ad3d9cbc72dedc789ac0511f78de" +dependencies = [ + "autocfg", + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "parking_lot" version = "0.10.2" @@ -1267,6 +1438,12 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" +[[package]] +name = "percent-encoding" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4fd5641d01c8f18a23da7b6fe29298ff4b55afcccdf78973b24cf3175fee32e" + [[package]] name = "pin-project" version = "0.4.22" @@ -1546,6 +1723,50 @@ version = "0.6.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26412eb97c6b088a6997e05f69403a802a92d520de2f8e63c2b65f9e0f47c4e8" +[[package]] +name = "remove_dir_all" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3acd125665422973a33ac9d3dd2df85edad0f4ae9b00dafb1a05e43a9f5ef8e7" +dependencies = [ + "winapi 0.3.8", +] + +[[package]] +name = "reqwest" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9eaa17ac5d7b838b7503d118fa16ad88f440498bf9ffe5424e621f93190d61e" +dependencies = [ + "base64 0.12.3", + "bytes", + "encoding_rs", + "futures-core", + "futures-util", + "http", + "http-body", + "hyper", + "hyper-tls", + "ipnet", + "js-sys", + "lazy_static", + "log", + "mime", + "mime_guess", + "native-tls", + "percent-encoding", + "pin-project-lite", + "serde", + "serde_urlencoded", + "tokio", + "tokio-tls", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "winreg", +] + [[package]] name = "ring" version = "0.16.15" @@ -1567,7 +1788,7 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2bc8af4bda8e1ff4932523b94d3dd20ee30a87232323eda55903ffd71d2fb017" dependencies = [ - "base64", + "base64 0.11.0", "blake2b_simd", "constant_time_eq", "crossbeam-utils", @@ -1622,7 +1843,7 @@ version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c0d4a31f5d68413404705d6982529b0e11a9aacd4839d1d6222ee3b8cb4015e1" dependencies = [ - "base64", + "base64 0.11.0", "log", "ring", "sct", @@ -1754,6 +1975,18 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_urlencoded" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ec5d77e2d4c73717816afac02670d5c4f534ea95ed430442cad02e7a6e32c97" +dependencies = [ + "dtoa", + "itoa", + "serde", + "url", +] + [[package]] name = "serial_test" version = "0.4.0" @@ -1835,6 +2068,20 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "tempfile" +version = "3.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a6e24d9338a0a5be79593e2fa15a648add6138caa803e2d5bc782c371732ca9" +dependencies = [ + "cfg-if", + "libc", + "rand 0.7.3", + "redox_syscall", + "remove_dir_all", + "winapi 0.3.8", +] + [[package]] name = "term" version = "0.6.1" @@ -1913,6 +2160,12 @@ dependencies = [ "time", ] +[[package]] +name = "tinyvec" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "238ce071d267c5710f9d31451efec16c5ee22de34df17cc05e56cbc92e967117" + [[package]] name = "tokio" version = "0.2.21" @@ -1926,8 +2179,21 @@ dependencies = [ "lazy_static", "memchr", "mio", + "num_cpus", "pin-project-lite", "slab", + "tokio-macros", +] + +[[package]] +name = "tokio-macros" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0c3acc6aa564495a0f2e1d59fab677cd7f81a19994cfc7f3ad0e64301560389" +dependencies = [ + "proc-macro2", + "quote", + "syn", ] [[package]] @@ -1942,6 +2208,16 @@ dependencies = [ "webpki", ] +[[package]] +name = "tokio-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a70f4fcd7b3b24fb194f837560168208f669ca8cb70d0c4b862944452396343" +dependencies = [ + "native-tls", + "tokio", +] + [[package]] name = "tokio-util" version = "0.3.1" @@ -1968,6 +2244,33 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e604eb7b43c06650e854be16a2a03155743d3752dd1c943f6829e26b7a36e382" +[[package]] +name = "unicase" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50f37be617794602aabbeee0be4f259dc1778fabe05e2d67ee8f79326d5cb4f6" +dependencies = [ + "version_check", +] + +[[package]] +name = "unicode-bidi" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49f2bd0c6468a8230e1db229cff8029217cf623c767ea5d60bfbd42729ea54d5" +dependencies = [ + "matches", +] + +[[package]] +name = "unicode-normalization" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fb19cf769fa8c6a80a162df694621ebeb4dafb606470b2b2fce0be40a98a977" +dependencies = [ + "tinyvec", +] + [[package]] name = "unicode-width" version = "0.1.7" @@ -1986,6 +2289,23 @@ version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" +[[package]] +name = "url" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "829d4a8476c35c9bf0bbce5a3b23f4106f79728039b726d292bb93bc106787cb" +dependencies = [ + "idna", + "matches", + "percent-encoding", +] + +[[package]] +name = "vcpkg" +version = "0.2.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6454029bf181f092ad1b853286f23e2c507d8e8194d01d92da4a55c274a5508c" + [[package]] name = "vec_map" version = "0.8.2" @@ -2021,6 +2341,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c2dc4aa152834bc334f506c1a06b866416a8b6697d5c9f75b9a689c8486def0" dependencies = [ "cfg-if", + "serde", + "serde_json", "wasm-bindgen-macro", ] @@ -2039,6 +2361,18 @@ dependencies = [ "wasm-bindgen-shared", ] +[[package]] +name = "wasm-bindgen-futures" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64487204d863f109eb77e8462189d111f27cb5712cc9fdb3461297a76963a2f6" +dependencies = [ + "cfg-if", + "js-sys", + "wasm-bindgen", + "web-sys", +] + [[package]] name = "wasm-bindgen-macro" version = "0.2.63" @@ -2140,6 +2474,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "winreg" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0120db82e8a1e0b9fb3345a539c478767c0048d842860994d96113d5b667bd69" +dependencies = [ + "winapi 0.3.8", +] + [[package]] name = "ws2_32-sys" version = "0.2.1" diff --git a/coaster/src/plugin.rs b/coaster/src/plugin.rs index 2df39a778..4d8d4dccc 100644 --- a/coaster/src/plugin.rs +++ b/coaster/src/plugin.rs @@ -22,7 +22,7 @@ //! //! Extending the Backend with your own Plugin is a straight forward process. //! For now we recommend that you take a look at the general code structure of [Coaster-BLAS][coaster-blas] -//! or its documentation. Let us now about your Plugin on the Gitter chat, we are happy to feature +//! or its documentation. Let us know about your Plugin on the Gitter chat, we are happy to feature //! your Coaster Plugin on the README. //! //! [program]: ../program/index.html diff --git a/juice/src/layer.rs b/juice/src/layer.rs index dad9fbcce..09fbcb1bd 100644 --- a/juice/src/layer.rs +++ b/juice/src/layer.rs @@ -762,10 +762,10 @@ impl Layer { /// # } /// # } /// ``` - pub fn load + 'static, P: AsRef>( - backend: Rc, + pub fn load::F, f32> + 'static, P: AsRef>( + backend: Rc, path: P, - ) -> io::Result> { + ) -> io::Result> { let path = path.as_ref(); let ref mut file = File::open(path)?; let mut reader = BufReader::new(file); @@ -951,7 +951,7 @@ impl<'a, B: IBackend> CapnpWrite<'a> for Layer { } } -impl + crate::coblas::plugin::Copy + 'static> Layer { +impl Layer where B: IBackend + LayerOps<::F,f32> + crate::coblas::plugin::Copy + 'static { /// Creates a new Layer from a [LayerConfig][1]. /// [1]: ./struct.LayerConfig.html pub fn from_config(backend: Rc, config: &LayerConfig) -> Layer { @@ -984,7 +984,7 @@ impl + crate::coblas::plugin::Copy + 'static> L backend: backend.clone(), - worker: Layer::::worker_from_config(backend, &cfg), + worker: ::F, f32>>::layer_from_config::(backend, &cfg), config: cfg, }; layer.expose_inputs(); @@ -992,31 +992,6 @@ impl + crate::coblas::plugin::Copy + 'static> L layer } - - /// Helper for [from_config] to match a [LayerType][2] to its [implementation][3]. - /// [1]: #method.from_config - /// [2]: ./enum.LayerType.html - /// [3]: ../layers/index.html - fn worker_from_config(backend: Rc, config: &LayerConfig) -> Box> { - match config.layer_type.clone() { - LayerType::Convolution(layer_config) => Box::new(Convolution::from_config(&layer_config)), - LayerType::Rnn(layer_config) => Box::new(Rnn::from_config(&layer_config)), - LayerType::Linear(layer_config) => Box::new(Linear::from_config(&layer_config)), - LayerType::LogSoftmax => Box::new(LogSoftmax::default()), - LayerType::Pooling(layer_config) => Box::new(Pooling::from_config(&layer_config)), - LayerType::Sequential(layer_config) => Box::new(Sequential::from_config(backend, &layer_config)), - LayerType::Softmax => Box::new(Softmax::default()), - LayerType::ReLU => Box::new(ReLU), - LayerType::TanH => Box::new(TanH), - LayerType::Sigmoid => Box::new(Sigmoid), - LayerType::NegativeLogLikelihood(layer_config) => { - Box::new(NegativeLogLikelihood::from_config(&layer_config)) - } - LayerType::MeanSquaredError => Box::new(MeanSquaredError), - LayerType::Reshape(layer_config) => Box::new(Reshape::from_config(&layer_config)), - LayerType::Dropout(layer_config) => Box::new(Dropout::from_config(&layer_config)), - } - } } /// A Layer in a Neural Network that can handle forward and backward of a computation step. diff --git a/juice/src/layers/common/linear.rs b/juice/src/layers/common/linear.rs index a22352ac2..67b7e5cc4 100644 --- a/juice/src/layers/common/linear.rs +++ b/juice/src/layers/common/linear.rs @@ -22,12 +22,18 @@ use crate::capnp_util::*; use crate::co::backend::IBackend; use crate::co::tensor::SharedTensor; +use crate::coblas::plugin::*; use crate::coblas::transpose::Transpose; use crate::juice_capnp::linear_config as capnp_config; use crate::layer::*; -use crate::util::{native_scalar, ArcLock, LayerOps}; +use crate::util::{native_scalar, ArcLock, LayerOps, Axpby}; use crate::weight::FillerType; + +trait ILinearCalc: Gemm + Axpby + Copy {} + +impl ILinearCalc for T where T: Gemm + Axpby + Copy {} + #[derive(Debug)] /// Linear Layer pub struct Linear { @@ -67,7 +73,7 @@ impl Linear { } } -impl> ILayer for Linear { +impl> ILayer for Linear { fn auto_weight_blobs(&self) -> bool { true } @@ -123,7 +129,7 @@ impl> ILayer for Linear { } } -impl> ComputeOutput for Linear { +impl> ComputeOutput for Linear { /// Basically, x has the shape (k, n) where k is the batch size. Given W with shape (m, n) where /// m is output vector length, we compute the output with the formula xW^T which will give us a /// matrix of size (k, m) with the outputs. @@ -162,7 +168,7 @@ impl> ComputeOutput for Linear { } } -impl> ComputeInputGradient for Linear { +impl> ComputeInputGradient for Linear { /// Since we have row vectors instead of columns, xW^T = (Wx^T)^T. Take the derivative with /// respect to x^T (gives us a column vector of dimension (n, 1)), we get d((Wx^T)^T)/d(x^T) = /// W^T of dims (n, m). In backpropagation with column vectors, we would take W^T * output_grad, @@ -192,7 +198,7 @@ impl> ComputeInputGradient for Linear { } } -impl> ComputeParametersGradient for Linear { +impl> ComputeParametersGradient for Linear { fn compute_parameters_gradient( &self, backend: &B, diff --git a/juice/src/layers/container/sequential.rs b/juice/src/layers/container/sequential.rs index a6ebef859..0a648fa73 100644 --- a/juice/src/layers/container/sequential.rs +++ b/juice/src/layers/container/sequential.rs @@ -13,7 +13,7 @@ use std::sync::{Arc, RwLock}; #[derive(Debug)] /// Sequential Layer -pub struct Sequential> { +pub struct Sequential::F,f32>> { layers: Vec>>, input_tensor_names: Vec, @@ -26,7 +26,7 @@ pub struct Sequential> { registry: HashMap>, ArcLock>)>, } -impl + 'static> Sequential { +impl::F,f32> + 'static> Sequential { /// Create a empty Sequential container layer. pub fn empty() -> Sequential { Sequential { @@ -219,7 +219,7 @@ impl + 'static> Sequential { } } -impl + 'static> ILayer for Sequential { +impl::F,f32> + 'static> ILayer for Sequential { fn is_container(&self) -> bool { true } @@ -344,7 +344,7 @@ impl + 'static> ILayer for Sequential { } } -impl + 'static> ComputeOutput for Sequential { +impl::F,f32> + 'static> ComputeOutput for Sequential { // we are overriding `forward` and not calling `compute_output` fn compute_output( &self, @@ -356,7 +356,7 @@ impl + 'static> ComputeOutput for Sequential } } -impl + 'static> ComputeInputGradient for Sequential { +impl::F,f32> + 'static> ComputeInputGradient for Sequential { // we are overriding `backward_input` and not calling `compute_input_gradient` fn compute_input_gradient( &self, @@ -370,7 +370,7 @@ impl + 'static> ComputeInputGradient for Seq } } -impl + 'static> ComputeParametersGradient for Sequential { +impl::F,f32> + 'static> ComputeParametersGradient for Sequential { // we are overriding `backward_parameters` and not calling `compute_parameters_gradient` fn compute_parameters_gradient( &self, diff --git a/juice/src/layers/loss/mean_squared_error.rs b/juice/src/layers/loss/mean_squared_error.rs index 7964dfb6d..86fe46382 100644 --- a/juice/src/layers/loss/mean_squared_error.rs +++ b/juice/src/layers/loss/mean_squared_error.rs @@ -23,7 +23,7 @@ impl MeanSquaredError { } } -impl + Axpby> ILayer for MeanSquaredError { +impl::F,f32> + Axpby> ILayer for MeanSquaredError { fn reshape( &mut self, backend: ::std::rc::Rc, @@ -70,7 +70,7 @@ impl ComputeOutput for MeanSquaredError { } // Calculate a Gradient for Mean Squared Error -impl> ComputeInputGradient for MeanSquaredError { +impl::F,f32>> ComputeInputGradient for MeanSquaredError { fn compute_input_gradient( &self, backend: &B, diff --git a/juice/src/solver/mod.rs b/juice/src/solver/mod.rs index d07fc4628..e7a3a2193 100644 --- a/juice/src/solver/mod.rs +++ b/juice/src/solver/mod.rs @@ -20,7 +20,7 @@ use std::rc::Rc; #[derive(Debug)] /// Solver that optimizes a [Layer][1] with a given objective. /// [1]: ../layer/index.html -pub struct Solver, B: IBackend + LayerOps> { +pub struct Solver, B: IBackend + LayerOps<::F,f32>> { net: Layer, objective: Layer, /// The implementation of the Solver @@ -34,7 +34,7 @@ pub struct Solver, B: IBackend + LayerOps, } -impl + 'static, B: IBackend + LayerOps + 'static> Solver { +impl + 'static, B: IBackend + LayerOps<::F,f32> + 'static> Solver { /// Create Solver from [SolverConfig][1] /// [1]: ./struct.SolverConfig.html /// @@ -56,7 +56,7 @@ impl + 'static, B: IBackend + LayerOps + } } -impl + 'static, B: IBackend + LayerOps + 'static> Solver { +impl + 'static, B: IBackend + LayerOps<::F,f32> + 'static> Solver { fn init(&mut self, backend: Rc) { info!("Initializing solver from configuration"); @@ -112,7 +112,7 @@ impl + 'static, B: IBackend + LayerOps + /// /// See [Solvers][1] /// [1]: ../solvers/index.html -pub trait ISolver> { +pub trait ISolver::F,f32>> { /// Initialize the solver, setting up any network related data. fn init(&mut self, net: &Layer) {} @@ -133,7 +133,7 @@ pub trait ISolver> { fn backend(&self) -> &SolverB; } -impl> ::std::fmt::Debug for dyn ISolver { +impl::F,f32>> ::std::fmt::Debug for dyn ISolver { fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { write!(f, "({})", "ILayer") } @@ -216,7 +216,7 @@ pub struct SolverConfig { /// The value should always be between 0 and 1 and dictates how much of the previous /// gradient update will be added to the current one. /// - /// Default: 0 + /// Default: 0.0 pub momentum: f32, } @@ -240,7 +240,7 @@ impl Default for SolverConfig { weight_decay: None, regularization_method: None, - momentum: 0f32, + momentum: 0.0f32, } } } @@ -338,11 +338,11 @@ pub enum SolverKind { impl SolverKind { /// Create a Solver of the specified kind with the supplied SolverConfig. - pub fn with_config + 'static, NetB: IBackend + LayerOps + 'static>( + pub fn with_config + 'static, NetB: IBackend + LayerOps<::F,f32> + 'static>( &self, - backend: Rc, + backend: Rc, config: &SolverConfig, - ) -> Box> { + ) -> Box> { match *self { SolverKind::SGD(sgd) => sgd.with_config(backend, config), } @@ -359,7 +359,7 @@ pub enum SGDKind { impl SGDKind { /// Create a Solver of the specified kind with the supplied SolverConfig. - pub fn with_config + 'static, NetB: IBackend + LayerOps + 'static>( + pub fn with_config + 'static, NetB: IBackend + LayerOps<::F,f32> + 'static>( &self, backend: Rc, config: &SolverConfig, diff --git a/juice/src/solvers/mod.rs b/juice/src/solvers/mod.rs index 08ab06916..b6388a915 100644 --- a/juice/src/solvers/mod.rs +++ b/juice/src/solvers/mod.rs @@ -36,7 +36,7 @@ use crate::layer::*; use crate::solver::*; use crate::util::*; -trait SGDSolver, NetB: IBackend + LayerOps>: ISolver { +trait SGDSolver, NetB: IBackend + LayerOps<::F,f32>>: ISolver { fn compute_update_value( &mut self, config: &SolverConfig, @@ -59,7 +59,7 @@ trait SGDSolver, NetB: IBackend + LayerOps + 'static>(&self, config: &SolverConfig, net: &mut Layer) { + fn clip_gradients::F,f32> + 'static>(&self, config: &SolverConfig, net: &mut Layer) { // skip clipping gradients if SolverConfig.clip_gradients is set to None if let Some(clip_threshold) = config.clip_gradients { let native = native_backend(); diff --git a/juice/src/solvers/sgd/mod.rs b/juice/src/solvers/sgd/mod.rs index e39f22f53..7605385bf 100644 --- a/juice/src/solvers/sgd/mod.rs +++ b/juice/src/solvers/sgd/mod.rs @@ -26,7 +26,7 @@ pub use self::momentum::Momentum; #[macro_export] macro_rules! impl_isolver_sgd { ($t:ty) => { - impl, NetB: IBackend + LayerOps + 'static> ISolver + impl, NetB: IBackend + LayerOps<::F,f32> + 'static> ISolver for $t { /// Initialize the SGD Momentum solver, allocating memory for its history. diff --git a/juice/src/solvers/sgd/momentum.rs b/juice/src/solvers/sgd/momentum.rs index ab396091b..60b03c5db 100644 --- a/juice/src/solvers/sgd/momentum.rs +++ b/juice/src/solvers/sgd/momentum.rs @@ -56,7 +56,7 @@ impl> Momentum { } } -impl, NetB: IBackend + LayerOps + 'static> SGDSolver for Momentum { +impl, NetB: IBackend + LayerOps<::F,f32> + 'static> SGDSolver for Momentum { fn compute_update_value( &mut self, config: &SolverConfig, diff --git a/juice/src/util.rs b/juice/src/util.rs index bfee7acf9..1a201c445 100644 --- a/juice/src/util.rs +++ b/juice/src/util.rs @@ -1,9 +1,12 @@ //! Provides common utility functions +use std::rc::Rc; + use crate::co::frameworks::native::flatbox::FlatBox; use crate::co::prelude::*; use crate::coblas::plugin::*; use crate::conn; +use crate::layer::{LayerType,LayerConfig,ILayer}; use num::traits::{cast, NumCast}; use std::sync::{Arc, RwLock}; @@ -95,49 +98,101 @@ pub trait Axpby: Axpy + Scal { impl + Scal> Axpby for T {} /// Encapsulates all traits required by Solvers. -// pub trait SolverOps : Axpby + Dot + Copy {} -// -// impl + Dot + Copy> SolverOps for T {} -pub trait SolverOps: LayerOps + Axpby + Dot + Copy {} - -impl + Axpby + Dot + Copy> SolverOps for T {} - -/// Encapsulates all traits used in Layers. -pub trait LayerOps: - conn::Convolution - + conn::Rnn - + conn::Pooling - + conn::Relu - + conn::ReluPointwise - + conn::Sigmoid - + conn::SigmoidPointwise - + conn::Tanh - + conn::TanhPointwise - + conn::Softmax - + conn::LogSoftmax - + conn::Dropout - + Gemm - + Axpby - + Copy +pub trait SolverOps: Axpby + Dot + Copy {} + +impl + Dot + Copy> SolverOps for T {} + + +use crate::layers::*; + + +/// Encapsulates all traits used in Layers per `Framework` and data type `F`. +pub trait LayerOps : coblas::plugin::Copy { + /// Helper for [from_config] to match a [LayerType][2] to its [implementation][3]. + /// [1]: #method.from_config + /// [2]: ./enum.LayerType.html + /// [3]: ../layers/index.html + fn layer_from_config>(backend: Rc, config: &LayerConfig) -> Box>; } -impl< - T: conn::Convolution - + conn::Rnn - + conn::Pooling - + conn::Relu - + conn::ReluPointwise - + conn::Sigmoid - + conn::SigmoidPointwise - + conn::Tanh - + conn::TanhPointwise - + conn::Softmax - + conn::LogSoftmax - + conn::Dropout - + Gemm - + Axpby - + Copy, - > LayerOps for T +#[cfg(feature = "native")] +impl LayerOps for T where + T: conn::Convolution + + conn::Rnn + + conn::Pooling + + conn::Relu + + conn::ReluPointwise + + conn::Sigmoid + + conn::SigmoidPointwise + + conn::Tanh + + conn::TanhPointwise + + conn::Softmax + + conn::LogSoftmax + + conn::Dropout + + Gemm + + Axpby + + Copy { + fn layer_from_config>(backend: Rc, config: &LayerConfig) -> Box> { + match config.layer_type { + LayerType::Linear(layer_config) => Box::new(Linear::from_config(&layer_config)), + LayerType::LogSoftmax => Box::new(LogSoftmax::default()), + LayerType::Pooling(layer_config) => Box::new(Pooling::from_config(&layer_config)), + LayerType::Sequential(layer_config) => Box::new(Sequential::from_config(backend, &layer_config)), + LayerType::Softmax => Box::new(Softmax::default()), + LayerType::ReLU => Box::new(ReLU), + LayerType::TanH => Box::new(TanH), + LayerType::Sigmoid => Box::new(Sigmoid), + LayerType::NegativeLogLikelihood(layer_config) => { + Box::new(NegativeLogLikelihood::from_config(&layer_config)) + } + LayerType::MeanSquaredError => Box::new(MeanSquaredError), + LayerType::Reshape(layer_config) => Box::new(Reshape::from_config(&layer_config)), + LayerType::Dropout(layer_config) => Box::new(Dropout::from_config(&layer_config)), + unsupported => panic!("Native does not support the requested layer type {:?}", unsupported), + } + } +} + +#[cfg(feature = "cuda")] +impl LayerOps for T where + T: conn::Convolution + + conn::Rnn + + conn::Pooling + + conn::Relu + + conn::ReluPointwise + + conn::Sigmoid + + conn::SigmoidPointwise + + conn::Tanh + + conn::TanhPointwise + + conn::Softmax + + conn::LogSoftmax + + conn::Dropout + + Gemm + + Axpby + + Copy +{ + fn layer_from_config>(backend: Rc, config: &LayerConfig) -> Box> { + // fn layer_from_config(backend: Rc>, config: &LayerConfig) -> Box>> { + match config.layer_type { + LayerType::Convolution(layer_config) => Box::new(Convolution::from_config(layer_config)), + LayerType::Rnn(layer_config) => Box::new(Rnn::from_config(&layer_config)), + LayerType::Linear(layer_config) => Box::new(Linear::from_config(&layer_config)), + LayerType::LogSoftmax => Box::new(LogSoftmax::default()), + LayerType::Pooling(layer_config) => Box::new(Pooling::from_config(&layer_config)), + LayerType::Sequential(layer_config) => Box::new(Sequential::from_config(backend, &layer_config)), + LayerType::Softmax => Box::new(Softmax::default()), + LayerType::ReLU => Box::new(ReLU), + LayerType::TanH => Box::new(TanH), + LayerType::Sigmoid => Box::new(Sigmoid), + LayerType::NegativeLogLikelihood(layer_config) => { + Box::new(NegativeLogLikelihood::from_config(&layer_config)) + } + LayerType::MeanSquaredError => Box::new(MeanSquaredError), + LayerType::Reshape(layer_config) => Box::new(Reshape::from_config(&layer_config)), + LayerType::Dropout(layer_config) => Box::new(Dropout::from_config(&layer_config)), + } + } + }