Merge pull request #117 from kundajelab/add_sigmoid_tanh
Added support for sigmoid (at intermediate layers) and tanh activations
AvantiShri authored Nov 11, 2020
2 parents 87bb578 + 9f7022c commit fa4957b
Showing 5 changed files with 115 additions and 8 deletions.
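
In practice, this change means a saved Keras model that uses sigmoid or tanh on hidden layers can now be converted without the gradient being stubbed out. A minimal sketch of the intended usage, assuming a hypothetical saved model at "model.h5" with sigmoid/tanh hidden activations (the converter and scoring calls below are the ones exercised by the new test file):

    import numpy as np
    from deeplift.conversion import kerasapi_conversion as kc
    from deeplift.layers import NonlinearMxtsMode

    # Convert a saved Keras model; hidden-layer sigmoid/tanh now get real gradients.
    deeplift_model = kc.convert_model_from_saved_files(
        "model.h5",  # hypothetical path to a model with tanh/sigmoid hidden layers
        nonlinear_mxts_mode=NonlinearMxtsMode.Gradient)

    # Score inputs against the layer just before the final output activation.
    contribs_func = deeplift_model.get_target_contribs_func(
        find_scores_layer_idx=0, target_layer_idx=-2)
    scores = contribs_func(task_idx=0,
                           input_data_list=[np.random.randn(10, 51, 10)],
                           batch_size=10, progress_update=None)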
2 changes: 2 additions & 0 deletions .travis.yml
@@ -9,6 +9,8 @@ matrix:
       env: KERAS_BACKEND=tensorflow TF_VERSION=1.10.1 KERAS_VERSION=2.2
     - python: 3.6
       env: KERAS_BACKEND=tensorflow TF_VERSION=1.10.1 KERAS_VERSION=2.2.4
+    - python: 3.6
+      env: KERAS_BACKEND=tensorflow TF_VERSION=1.14.0 KERAS_VERSION=2.2.4

 notifications:
   email: true
7 changes: 7 additions & 0 deletions deeplift/conversion/kerasapi_conversion.py
@@ -38,6 +38,7 @@
     relu='relu',
     prelu='prelu',
     sigmoid='sigmoid',
+    tanh='tanh',
     softmax='softmax',
     linear='linear')

@@ -74,6 +75,11 @@ def sigmoid_conversion(name, verbose, nonlinear_mxts_mode, **kwargs):
                 nonlinear_mxts_mode=nonlinear_mxts_mode)]


+def tanh_conversion(name, verbose, nonlinear_mxts_mode, **kwargs):
+    return [layers.activations.Tanh(name=name, verbose=verbose,
+                nonlinear_mxts_mode=nonlinear_mxts_mode)]
+
+
 def softmax_conversion(name, verbose, nonlinear_mxts_mode, **kwargs):
     return [layers.activations.Softmax(name=name, verbose=verbose,
                 nonlinear_mxts_mode=nonlinear_mxts_mode)]
@@ -298,6 +304,7 @@ def activation_to_conversion_function(activation_name):
         ActivationTypes.linear: linear_conversion,
         ActivationTypes.relu: relu_conversion,
         ActivationTypes.sigmoid: sigmoid_conversion,
+        ActivationTypes.tanh: tanh_conversion,
         ActivationTypes.softmax: softmax_conversion
     }
     return activation_dict[activation_name.lower()]
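
For context, the dispatch above is a plain name-to-function lookup: activation_to_conversion_function lowercases the Keras activation name and returns the matching *_conversion factory, which builds the corresponding DeepLIFT layer. A small sketch of how the new entry would be exercised (the call signature is the one shown in the diff; treat this as an illustration, not library documentation):

    from deeplift.conversion import kerasapi_conversion as kc
    from deeplift.layers import NonlinearMxtsMode

    # "Tanh" lowercases to "tanh" and now resolves instead of raising KeyError.
    convert = kc.activation_to_conversion_function("Tanh")
    tanh_layers = convert(name="tanh_1", verbose=True,
                          nonlinear_mxts_mode=NonlinearMxtsMode.Gradient)
    # tanh_layers is a one-element list holding a deeplift.layers.activations.Tanh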
21 changes: 14 additions & 7 deletions deeplift/layers/activations.py
@@ -215,13 +215,20 @@ def _build_activation_vars(self, input_act_vars):
         return tf.nn.sigmoid(input_act_vars)

     def _get_gradient_at_activation(self, activation_vars):
-        if (self.verbose == True):
-            print("Heads-up: I assume sigmoid is the output layer, "
-                  "not an intermediate one; if it's an intermediate layer "
-                  "then please bug me and I will implement the grad func")
-        return 0.0 #punting; not implemented for tensorflow yet.
-                   #This shouldn't be needed unless you
-                   #have hidden-unit sigmoid activations
+        #derivative: https://towardsdatascience.com/derivative-of-the-sigmoid-function-536880cf918e
+        out_act = self._build_activation_vars(activation_vars)
+        return out_act*(1-out_act)


+class Tanh(Activation):
+
+    def _build_activation_vars(self, input_act_vars):
+        return tf.nn.tanh(input_act_vars)
+
+    def _get_gradient_at_activation(self, activation_vars):
+        #derivative: https://blogs.cuit.columbia.edu/zp2130/derivative_of_tanh_function/
+        out_act = self._build_activation_vars(activation_vars)
+        return 1 - (out_act*out_act)
+
+
 class Softmax(Activation):
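
The two closed-form gradients above are the standard ones: for s = sigmoid(x), ds/dx = s*(1-s), and for t = tanh(x), dt/dx = 1 - t^2. A quick self-contained numpy sanity check of those formulas against central finite differences (independent of deeplift and TensorFlow, just to verify the math):

    import numpy as np

    x = np.linspace(-4.0, 4.0, 9)
    eps = 1e-5

    sig = lambda z: 1.0 / (1.0 + np.exp(-z))

    # Central finite differences approximate the "true" derivatives.
    sig_fd = (sig(x + eps) - sig(x - eps)) / (2 * eps)
    tanh_fd = (np.tanh(x + eps) - np.tanh(x - eps)) / (2 * eps)

    # Closed forms used by the new _get_gradient_at_activation implementations.
    s, t = sig(x), np.tanh(x)
    np.testing.assert_allclose(s * (1 - s), sig_fd, atol=1e-7)
    np.testing.assert_allclose(1 - t * t, tanh_fd, atol=1e-7)
    print("sigmoid and tanh gradient formulas match finite differences")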
2 changes: 1 addition & 1 deletion setup.py
@@ -8,7 +8,7 @@
 Implements the methods in "Learning Important Features Through Propagating Activation Differences" by Shrikumar, Greenside & Kundaje, as well as other commonly-used methods such as gradients, guided backprop and integrated gradients. See https://github.com/kundajelab/deeplift for documentation and FAQ.
 """,
     url='https://github.com/kundajelab/deeplift',
-    version='0.6.12.0',
+    version='0.6.13.0',
     packages=['deeplift',
               'deeplift.layers', 'deeplift.visualization',
               'deeplift.conversion'],
91 changes: 91 additions & 0 deletions tests/conversion/sequential/test_sigmoid_tanh_activations.py
@@ -0,0 +1,91 @@
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import
import unittest
from unittest import skip
import sys
import os
import numpy as np
from deeplift.conversion import kerasapi_conversion as kc
import deeplift.layers as layers
from deeplift.layers import NonlinearMxtsMode
from deeplift.util import compile_func
import tensorflow as tf
import keras
from keras import models
from keras import backend as K


class TestConvolutionalModel(unittest.TestCase):


    def setUp(self):
        self.inp = (np.random.randn(10*10*51)
                    .reshape(10,10,51)).transpose(0,2,1)
        self.keras_model = keras.models.Sequential()
        self.keras_model.add(keras.layers.InputLayer((51,10)))
        conv_layer = keras.layers.convolutional.Convolution1D(
                        nb_filter=2, filter_length=4, subsample_length=2,
                        activation="sigmoid", input_shape=(51,10))
        self.keras_model.add(conv_layer)
        conv_layer2 = keras.layers.convolutional.Convolution1D(
                        nb_filter=2, filter_length=4,
                        activation="tanh", padding="same")
        self.keras_model.add(conv_layer2)
        self.keras_model.add(keras.layers.pooling.MaxPooling1D(
                             pool_length=4, stride=2))
        self.keras_model.add(keras.layers.pooling.AveragePooling1D(
                             pool_length=4, stride=2))
        self.keras_model.add(keras.layers.Flatten())
        self.keras_model.add(keras.layers.Dense(output_dim=1))
        self.keras_model.add(keras.layers.core.Activation("sigmoid"))
        self.keras_model.compile(loss="mse", optimizer="sgd")
        self.keras_output_fprop_func = compile_func(
            [self.keras_model.layers[0].input,
             K.learning_phase()],
            self.keras_model.layers[-1].output)

        grad = tf.gradients(tf.reduce_sum(
            self.keras_model.layers[-2].output[:,0]),
            [self.keras_model.layers[0].input])[0]
        self.grad_func = compile_func(
            [self.keras_model.layers[0].input,
             K.learning_phase()], grad)

        self.saved_file_path = "conv1model_validpadding.h5"
        if (os.path.isfile(self.saved_file_path)):
            os.remove(self.saved_file_path)
        self.keras_model.save(self.saved_file_path)

    def test_convert_conv1d_model_forward_prop(self):
        deeplift_model =\
            kc.convert_model_from_saved_files(
                self.saved_file_path,
                nonlinear_mxts_mode=NonlinearMxtsMode.Gradient)
        deeplift_fprop_func = compile_func(
            inputs=[deeplift_model.get_layers()[0].get_activation_vars()],
            outputs=deeplift_model.get_layers()[-1].get_activation_vars())
        np.testing.assert_almost_equal(
            deeplift_fprop_func(self.inp),
            self.keras_output_fprop_func([self.inp, 0]),
            decimal=6)

    def test_convert_conv1d_model_compute_scores(self):
        deeplift_model =\
            kc.convert_model_from_saved_files(self.saved_file_path,
                nonlinear_mxts_mode=NonlinearMxtsMode.Gradient)
        deeplift_contribs_func = deeplift_model.\
            get_target_contribs_func(
                find_scores_layer_idx=0,
                target_layer_idx=-2)
        np.testing.assert_almost_equal(
            deeplift_contribs_func(task_idx=0,
                                   input_data_list=[self.inp],
                                   batch_size=10,
                                   progress_update=None),
            self.grad_func([self.inp, 0])*self.inp, decimal=6)

    def tearDown(self):
        if (os.path.isfile(self.saved_file_path)):
            os.remove(self.saved_file_path)
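
Both tests compare the converted DeepLIFT model directly against Keras: one checks forward-pass parity, the other checks that Gradient-mode contribution scores equal gradient-times-input. Assuming pytest is available in the test environment, running just this module would look something like: pytest tests/conversion/sequential/test_sigmoid_tanh_activations.py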
