Merge pull request #349 from ufal/tf1.0
Introducing TensorFlow 1.0 branch
jindrahelcl authored Mar 12, 2017
2 parents e5ab2ea + 0850cd3 commit ef855bd
Showing 50 changed files with 324 additions and 433 deletions.
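Most of the diff below applies the mechanical TensorFlow 0.11 → 1.0 API renames. As a rough orientation, a minimal sketch of the recurring substitutions follows; the tensors and sizes in it are illustrative only and are not taken from the repository.

.. code-block:: python

    import tensorflow as tf

    # Recurring TF 0.11 -> 1.0 substitutions seen throughout this branch:
    #   tf.pack / tf.unpack                   -> tf.stack / tf.unstack
    #   tf.concat(axis, values)               -> tf.concat(values, axis)
    #   tf.nn.rnn_cell.*                      -> tf.contrib.rnn.*
    #   tf.scalar_summary / tf.image_summary  -> tf.summary.scalar / tf.summary.image
    #   tf.nn.seq2seq.sequence_loss           -> tf.contrib.seq2seq.sequence_loss

    values = [tf.ones([2, 3]), tf.zeros([2, 3])]   # illustrative tensors
    stacked = tf.stack(values)                     # was: tf.pack(values)
    merged = tf.concat(values, 1)                  # was: tf.concat(1, values)
    cell = tf.contrib.rnn.GRUCell(128)             # was: tf.nn.rnn_cell.GRUCell(128)
    cost = tf.reduce_mean(merged)
    summary = tf.summary.scalar("cost", cost)      # was: tf.scalar_summary("cost", cost)

The remaining edits switch the cross-entropy helpers to keyword arguments and replace the project's custom bidirectional RNN layer with ``tf.nn.bidirectional_dynamic_rnn``; both are illustrated next to the corresponding files below.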
2 changes: 1 addition & 1 deletion .readthedocs-conda.yml
@@ -12,5 +12,5 @@ dependencies:
- numpy
- pillow
- git+https://github.com/aflc/pyter@857a1552443f139a3
- https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.11.0-cp35-cp35m-linux_x86_64.whl
- tensorflow
- sphinx==1.5.1
5 changes: 2 additions & 3 deletions .travis.yml
@@ -4,8 +4,6 @@ dist: trusty
language: python

env:
global:
- TF=0.11.0-cp35-cp35m
matrix:
- TEST_SUITE=lint
- TEST_SUITE=pycodestyle
@@ -23,8 +21,9 @@ python:

# commands to install dependencies
before_install:
- sudo apt-get install libtcmalloc-minimal4
- export LD_PRELOAD="/usr/lib/libtcmalloc_minimal.so.4"
- pip install -r requirements.txt
- pip install https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-$TF""-linux_x86_64.whl
- if [ -f tests/$TEST_SUITE""_requirements.txt ]; then pip install -r tests/$TEST_SUITE""_requirements.txt; fi
- if [ -f tests/$TEST_SUITE""_install.sh ]; then tests/$TEST_SUITE""_install.sh; fi

1 change: 1 addition & 0 deletions docs/source/conf.py
@@ -34,6 +34,7 @@
'sphinx.ext.todo',
'sphinx.ext.coverage',
'sphinx.ext.pngmath',
'sphinx.ext.intersphinx'
]

# Add any paths that contain templates here, relative to this directory.
7 changes: 7 additions & 0 deletions docs/source/install.rst
@@ -41,3 +41,10 @@ and CuDNN installations. Similarly, your ``PATH`` variable should point to the
``bin`` subdirectory of the CUDA installation directory.

You made it! Neural Monkey is now installed!

Note for Ubuntu 14.04 users
***************************

If you get segmentation fault errors at the very end of the training process,
you can either ignore them or follow the steps outlined in `this
document <ubuntu1404_fix.html>`_.
4 changes: 2 additions & 2 deletions docs/source/machine_translation.rst
@@ -238,6 +238,7 @@ The following sections are described in more detail in
class=tf_manager.TensorFlowManager
num_threads=4
num_sessions=1
minimize_metric=False
save_n_best=3
.. TUTCHECK exp-nm-mt/translation.ini
@@ -254,7 +255,6 @@ As for the main configuration section do not forget to add BPE postprocessing:
train_dataset=<train_data>
val_dataset=<val_data>
evaluation=[("series_named_greedy", "target", <bleu>), ("series_named_greedy", "target", evaluators.ter.TER)]
minimize=False
batch_size=80
runners_batch_size=256
epochs=10
@@ -277,7 +277,7 @@ As for the evaluation, you need to create ``translation_run.ini``:
[main]
test_datasets=[<eval_data>]
[bpe_preprocess]
class=processors.bpe.BPEPreprocessor
merge_file="exp-nm-mt/data/merge_file.bpe"
2 changes: 1 addition & 1 deletion docs/source/tutorial.rst
@@ -458,6 +458,7 @@ TensorFlow should use, you need to specify a "TensorFlow manager":
class=tf_manager.TensorFlowManager
num_threads=4
num_sessions=1
minimize_metric=True
save_n_best=3
.. TUTCHECK exp-nm-ape/post-edit.ini
@@ -480,7 +481,6 @@ parameters:
train_dataset=<train_dataset>
val_dataset=<val_dataset>
evaluation=[("greedy_edits", "edits", <bleu>), ("greedy_edits", "edits", evaluators.ter.TER)]
minimize=True
batch_size=128
runners_batch_size=256
epochs=100
57 changes: 57 additions & 0 deletions docs/source/ubuntu1404_fix.rst
@@ -0,0 +1,57 @@
Fixing segmentation fault on exit on Ubuntu 14.04
=================================================

* On Ufal machines, the segfault can be prevented by doing this:

.. code-block:: bash
export LD_PRELOAD=/home/helcl/lib/libtcmalloc_minimal.so.4
bin/neuralmonkey-train tests/vocab.ini
* On machines with ``sudo``, one can do this:

.. code-block:: bash
sudo apt-get install libtcmalloc-minimal4
export LD_PRELOAD="/usr/lib/libtcmalloc_minimal.so.4"
* On machines with neither ``sudo`` nor access to
``~helcl/lib/libtcmalloc_minimal.so.4``, fix the segfault as follows:

.. code-block:: bash
wget http://archive.ubuntu.com/ubuntu/pool/main/g/google-perftools/google-perftools_2.1.orig.tar.gz
tar xpzvf google-perftools_2.1.orig.tar.gz
cd gperftools-2.1/
./configure --prefix=$HOME
make
make install
If the compilation fails because it needs the ``libunwind`` library (as it did
for me), do this:

.. code-block:: bash
wget http://download.savannah.gnu.org/releases/libunwind/libunwind-0.99-beta.tar.gz
tar xpzvf libunwind-0.99-beta.tar.gz
cd libunwind-0.99-beta/
./configure --prefix=$HOME
make
make install
If, by any chance, this compilation fails with an error like ``error:
'longjmp' aliased to undefined symbol '_longjmp'``, replace the ``make`` call
with ``make CFLAGS+=-U_FORTIFY_SOURCE``.

Then, in the ``$HOME/share`` directory, create a file called ``config.site`` like this:

.. code-block:: bash
cat << EOF > $HOME/share/config.site
CPPFLAGS=-I$HOME/include
LDFLAGS=-L$HOME/lib
EOF
Then redo the configure-make-make install sequence for gperftools. Finally,
set the ``LD_PRELOAD`` environment variable to point to
``$HOME/lib/libtcmalloc_minimal.so.4``.
28 changes: 14 additions & 14 deletions neuralmonkey/decoders/decoder.py
@@ -204,11 +204,11 @@ def decode(rnn_outputs):

_, self.train_logits = decode(train_rnn_outputs)

train_targets = tf.unpack(self.train_inputs)
train_targets = tf.transpose(self.train_inputs)

self.train_loss = tf.nn.seq2seq.sequence_loss(
self.train_logits, train_targets,
tf.unpack(self.train_padding), len(self.vocabulary))
self.train_loss = tf.contrib.seq2seq.sequence_loss(
tf.stack(self.train_logits, 1), train_targets,
tf.transpose(self.train_padding))
self.cost = self.train_loss

self.train_logprobs = [tf.nn.log_softmax(l)
@@ -217,9 +217,9 @@ def decode(rnn_outputs):
self.decoded, self.runtime_logits = decode(
self.runtime_rnn_outputs)

self.runtime_loss = tf.nn.seq2seq.sequence_loss(
self.runtime_logits, train_targets,
tf.unpack(self.train_padding), len(self.vocabulary))
self.runtime_loss = tf.contrib.seq2seq.sequence_loss(
tf.stack(self.runtime_logits, 1), train_targets,
tf.transpose(self.train_padding))

self.runtime_logprobs = [tf.nn.log_softmax(l)
for l in self.runtime_logits]
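For orientation: ``tf.contrib.seq2seq.sequence_loss`` takes a single batch-major logits tensor and infers the vocabulary size, which is why the time-major logit lists above are stacked along axis 1 and the targets and padding are transposed. A minimal sketch with illustrative shapes, not the decoder's actual tensors:

.. code-block:: python

    import tensorflow as tf

    batch, time, vocab = 2, 5, 10                       # illustrative sizes
    logits = tf.random_normal([batch, time, vocab])     # [batch, time, vocab]
    targets = tf.zeros([batch, time], dtype=tf.int32)   # token ids, [batch, time]
    weights = tf.ones([batch, time])                    # padding mask, [batch, time]

    # TF 1.0: no explicit vocabulary size, logits are one 3-D tensor
    loss = tf.contrib.seq2seq.sequence_loss(logits, targets, weights)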
@@ -306,11 +306,11 @@ def _logit_function(self, state: tf.Tensor) -> tf.Tensor:
state = dropout(state, self.dropout_keep_prob, self.train_mode)
return tf.matmul(state, self.decoding_w) + self.decoding_b

def _get_rnn_cell(self) -> tf.nn.rnn_cell.RNNCell:
def _get_rnn_cell(self) -> tf.contrib.rnn.RNNCell:
if self._rnn_cell == 'GRU':
return tf.nn.rnn_cell.GRUCell(self.rnn_size)
return tf.contrib.rnn.GRUCell(self.rnn_size)
elif self._rnn_cell == 'LSTM':
return tf.nn.rnn_cell.LSTMCell(self.rnn_size)
return tf.contrib.rnn.LSTMCell(self.rnn_size)
else:
raise ValueError("Unknown RNN cell: {}".format(self._rnn_cell))

@@ -355,7 +355,7 @@ def _attention_decoder(
state = self.initial_state
elif self._rnn_cell == 'LSTM':
# pylint: disable=redefined-variable-type
state = tf.nn.rnn_cell.LSTMStateTuple(
state = tf.contrib.rnn.LSTMStateTuple(
self.initial_state, self.initial_state)
# pylint: enable=redefined-variable-type
else:
@@ -423,12 +423,12 @@ def _visualize_attention(self):

for i, a in enumerate(att_objects):
alignments = tf.expand_dims(tf.transpose(
tf.pack(a.attentions_in_time), perm=[1, 2, 0]), -1)
tf.stack(a.attentions_in_time), perm=[1, 2, 0]), -1)

tf.image_summary(
tf.summary.image(
"attention_{}".format(i), alignments,
collections=["summary_val_plots"],
max_images=256)
max_outputs=256)

def feed_dict(self, dataset: Dataset, train: bool=False) -> FeedDict:
"""Populate the feed dictionary for the decoder object
4 changes: 2 additions & 2 deletions neuralmonkey/decoders/encoder_projection.py
@@ -62,7 +62,7 @@ def func(train_mode: tf.Tensor,
" of encoder projection")

with tf.variable_scope("encoders_projection") as scope:
encoded_concat = tf.concat(1, [e.encoded for e in encoders])
encoded_concat = tf.concat([e.encoded for e in encoders], 1)
encoded_concat = dropout(
encoded_concat, dropout_keep_prob, train_mode)

@@ -90,7 +90,7 @@ def concat_encoder_projection(
assert rnn_size == sum(e.encoded.get_shape()[1].value
for e in encoders)

encoded_concat = tf.concat(1, [e.encoded for e in encoders])
encoded_concat = tf.concat([e.encoded for e in encoders], 1)

# pylint: disable=no-member
log("The inferred rnn_size of this encoder projection will be {}"
4 changes: 2 additions & 2 deletions neuralmonkey/decoders/multi_decoder.py
@@ -49,8 +49,8 @@ def __init__(self, main_decoder, regularization_decoders):
self.regularization_decoders = regularization_decoders

self._training_decoders = [main_decoder] + regularization_decoders
self._decoder_costs = tf.concat(0, [tf.expand_dims(d.cost, 0)
for d in self._training_decoders])
self._decoder_costs = tf.concat([tf.expand_dims(d.cost, 0)
for d in self._training_decoders], 0)

self._scheduled_decoder = 0
self._input_selector = tf.placeholder(tf.float32,
4 changes: 2 additions & 2 deletions neuralmonkey/decoders/output_projection.py
@@ -24,7 +24,7 @@ def no_deep_output(prev_state, prev_output, ctx_tensors):
Returns:
This function returns the concatenation of all its inputs.
"""
return tf.concat(1, [prev_state, prev_output] + ctx_tensors)
return tf.concat([prev_state, prev_output] + ctx_tensors, 1)


def maxout_output(maxout_size):
@@ -63,7 +63,7 @@ def mlp_output(layer_sizes, dropout_plc=None, activation=tf.tanh):
activation: The activation function to use in each layer.
"""
def _projection(prev_state, prev_output, ctx_tensors):
mlp_input = tf.concat(1, [prev_state, prev_output] + ctx_tensors)
mlp_input = tf.concat([prev_state, prev_output] + ctx_tensors, 1)

return multilayer_projection(mlp_input, layer_sizes,
activation=activation,
8 changes: 4 additions & 4 deletions neuralmonkey/decoders/sequence_classifier.py
@@ -62,25 +62,25 @@ def __init__(self,
tf.placeholder(tf.float32, name="dropout_plc")
self.gt_inputs = [tf.placeholder(
tf.int32, shape=[None], name="targets")]
mlp_input = tf.concat(1, [enc.encoded for enc in encoders])
mlp_input = tf.concat([enc.encoded for enc in encoders], 1)
mlp = MultilayerPerceptron(
mlp_input, layers, self.dropout_placeholder, len(vocabulary),
activation_fn=self.activation_fn)

self.loss_with_gt_ins = tf.reduce_mean(
tf.nn.sparse_softmax_cross_entropy_with_logits(
mlp.logits, self.gt_inputs[0]))
logits=mlp.logits, labels=self.gt_inputs[0]))
self.loss_with_decoded_ins = self.loss_with_gt_ins
self.cost = self.loss_with_gt_ins

self.decoded_seq = [mlp.classification]
self.decoded_logits = [mlp.logits]
self.runtime_logprobs = [tf.nn.log_softmax(mlp.logits)]

tf.scalar_summary(
tf.summary.scalar(
'val_optimization_cost', self.cost,
collections=["summary_val"])
tf.scalar_summary(
tf.summary.scalar(
'train_optimization_cost',
self.cost, collections=["summary_train"])
# pylint: enable=too-many-arguments
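TensorFlow 1.0 also makes keyword arguments mandatory for the cross-entropy helpers and moves the summary ops under ``tf.summary``, which is what the changes above reflect. A small self-contained sketch with illustrative sizes:

.. code-block:: python

    import tensorflow as tf

    logits = tf.random_normal([4, 10])    # [batch, classes], illustrative
    labels = tf.constant([1, 3, 0, 7])    # class ids, [batch]

    # keyword arguments are required in TF 1.0
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels))

    cost_summary = tf.summary.scalar("val_optimization_cost", loss,
                                     collections=["summary_val"])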
2 changes: 1 addition & 1 deletion neuralmonkey/decoders/sequence_labeler.py
@@ -86,7 +86,7 @@ def logits(self) -> tf.Tensor:
biases = tf.get_variable(
name="state_to_word_b",
shape=[vocabulary_size],
initializer=tf.zeros_initializer)
initializer=tf.zeros_initializer())

weights_direct = tf.get_variable(
name="emb_to_word_W",
8 changes: 4 additions & 4 deletions neuralmonkey/decoders/word_alignment_decoder.py
@@ -37,15 +37,15 @@ def __init__(self,
_, self.train_loss = self._make_decoder(runtime_mode=False)
self.decoded, self.runtime_loss = self._make_decoder(runtime_mode=True)

tf.scalar_summary("alignment_train_xent", self.train_loss,
tf.summary.scalar("alignment_train_xent", self.train_loss,
collections=["summary_train"])

def _make_decoder(self, runtime_mode=False):
attn_obj = self.decoder.get_attention_object(self.encoder,
runtime_mode)

alignment_logits = tf.pack(attn_obj.logits_in_time,
name="alignment_logits")
alignment_logits = tf.stack(attn_obj.logits_in_time,
name="alignment_logits")

if runtime_mode:
# make batch_size the first dimension
@@ -56,7 +56,7 @@ def _make_decoder(self, runtime_mode=False):
alignment = None

xent = tf.nn.softmax_cross_entropy_with_logits(
alignment_logits, self.alignment_target)
labels=self.alignment_target, logits=alignment_logits)
loss = tf.reduce_sum(xent * self.decoder.train_padding)

return alignment, loss
21 changes: 10 additions & 11 deletions neuralmonkey/encoders/factored_encoder.py
@@ -7,7 +7,6 @@
from neuralmonkey.model.model_part import ModelPart
from neuralmonkey.encoders.attentive import Attentive
from neuralmonkey.logging import log
from neuralmonkey.nn.bidirectional_rnn_layer import BidirectionalRNNLayer
from neuralmonkey.vocabulary import Vocabulary


Expand Down Expand Up @@ -72,7 +71,7 @@ def _attention_tensor(self):

def _get_rnn_cell(self):
"""Return the RNN cell for the encoder"""
return tf.nn.rnn_cell.GRUCell(self.rnn_size)
return tf.contrib.rnn.GRUCell(self.rnn_size)

def _get_birnn_cells(self):
"""Return forward and backward RNN cells for the encoder"""
@@ -138,25 +137,25 @@ def _create_encoder_graph(self):
# factors is a 2D list of embeddings of dims [factor-type, time-step]
# by doing zip(*factors), we get a list of (factor-type) embedding
# tuples indexed by the time step
concatenated_factors = [tf.concat(1, related_factors)
concatenated_factors = [tf.concat(related_factors, 1)
for related_factors in zip(*factors)]
assert_shape(concatenated_factors[0],
[None, sum(self.embedding_sizes)])
forward_gru, backward_gru = self._get_birnn_cells()

bidi_layer = BidirectionalRNNLayer(forward_gru, backward_gru,
concatenated_factors,
sentence_lengths)
stacked_factors = tf.stack(concatenated_factors, 1)

self.outputs_bidi = bidi_layer.outputs_bidi
self.encoded = bidi_layer.encoded
self.outputs_bidi, encoded_tup = tf.nn.bidirectional_dynamic_rnn(
forward_gru, backward_gru, stacked_factors,
sentence_lengths, dtype=tf.float32)

self.__attention_tensor = tf.concat(1, [tf.expand_dims(o, 1)
for o in self.outputs_bidi])
self.encoded = tf.concat(encoded_tup, 1)

self.__attention_tensor = tf.concat(self.outputs_bidi, 2)
self.__attention_tensor = tf.nn.dropout(self.__attention_tensor,
self.dropout_placeholder)
self.__attention_mask = tf.concat(
1, [tf.expand_dims(w, 1) for w in self.padding_weights])
[tf.expand_dims(w, 1) for w in self.padding_weights], 1)

# pylint: disable=too-many-locals
def feed_dict(self, dataset, train=False):
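Here the project's custom ``BidirectionalRNNLayer`` is replaced by ``tf.nn.bidirectional_dynamic_rnn``, which returns a pair of batch-major output tensors and a pair of final states; the encoder concatenates the former along axis 2 for attention and the latter along axis 1 as the encoded vector. A minimal sketch with illustrative names and sizes, not the encoder's actual tensors:

.. code-block:: python

    import tensorflow as tf

    batch, time, dim, rnn_size = 2, 7, 16, 32         # illustrative sizes
    inputs = tf.random_normal([batch, time, dim])     # batch-major inputs
    lengths = tf.constant([7, 5])                     # true sequence lengths

    fw_cell = tf.contrib.rnn.GRUCell(rnn_size)
    bw_cell = tf.contrib.rnn.GRUCell(rnn_size)

    # outputs: (fw, bw) pair of [batch, time, rnn_size] tensors
    # states:  (fw, bw) pair of [batch, rnn_size] final states
    outputs, states = tf.nn.bidirectional_dynamic_rnn(
        fw_cell, bw_cell, inputs, sequence_length=lengths, dtype=tf.float32)

    attention_tensor = tf.concat(outputs, 2)          # [batch, time, 2 * rnn_size]
    encoded = tf.concat(states, 1)                    # [batch, 2 * rnn_size]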
(The remaining changed files are not shown in this view.)