From 9260fc5af6a4df512ce984eee6a72e385fcf5f75 Mon Sep 17 00:00:00 2001
From: Jindrich Libovicky
Date: Thu, 1 Jun 2017 16:34:55 +0200
Subject: [PATCH] non-recurrent decoders use @tensor

---
 neuralmonkey/decoders/ctc_decoder.py         |  72 +++++----
 neuralmonkey/decoders/multi_decoder.py       | 148 -------------------
 neuralmonkey/decoders/sequence_classifier.py |  92 +++++++-----
 neuralmonkey/decoders/sequence_labeler.py    |  18 ++-
 4 files changed, 112 insertions(+), 218 deletions(-)
 delete mode 100644 neuralmonkey/decoders/multi_decoder.py

diff --git a/neuralmonkey/decoders/ctc_decoder.py b/neuralmonkey/decoders/ctc_decoder.py
index 9666e2267..f3b723f2d 100644
--- a/neuralmonkey/decoders/ctc_decoder.py
+++ b/neuralmonkey/decoders/ctc_decoder.py
@@ -32,33 +32,41 @@ def __init__(self,
         self.vocabulary = vocabulary
         self.data_id = data_id
 
-        self._merge_repeated_targets = merge_repeated_targets
-        self._merge_repeated_outputs = merge_repeated_outputs
-        self._beam_width = beam_width
-
-        with self.use_scope():
-            self.train_targets = tf.sparse_placeholder(tf.int32,
-                                                       name="targets")
-
-            self.train_mode = tf.placeholder(tf.bool, name="train_mode")
-
-            # encoder.states_mask is batch-major
-            self._input_lengths = tf.reduce_sum(
-                tf.to_int32(self.encoder.states_mask), 1)
-
-            if beam_width == 1:
-                decoded, _ = tf.nn.ctc_greedy_decoder(
-                    inputs=self._logits, sequence_length=self._input_lengths,
-                    merge_repeated=self._merge_repeated_outputs)
-            else:
-                decoded, _ = tf.nn.ctc_beam_search_decoder(
-                    inputs=self._logits, sequence_length=self._input_lengths,
-                    beam_width=self._beam_width,
-                    merge_repeated=self._merge_repeated_outputs)
-
-            self.decoded = tf.sparse_tensor_to_dense(
-                tf.sparse_transpose(decoded[0]),
-                default_value=self.vocabulary.get_word_index(END_TOKEN))
+        self.merge_repeated_targets = merge_repeated_targets
+        self.merge_repeated_outputs = merge_repeated_outputs
+        self.beam_width = beam_width
+    # pylint: enable=too-many-arguments
+
+    # pylint: disable=no-self-use
+    @tensor
+    def train_targets(self) -> tf.Tensor:
+        return tf.sparse_placeholder(tf.int32, name="targets")
+
+    @tensor
+    def train_mode(self) -> tf.Tensor:
+        return tf.placeholder(tf.bool, name="train_mode")
+    # pylint: disable=no-self-use
+
+    @tensor
+    def input_lengths(self) -> tf.Tensor:
+        # encoder.states_mask is batch-major
+        return tf.reduce_sum(tf.to_int32(self.encoder.states_mask), 1)
+
+    @tensor
+    def decoded(self) -> tf.Tensor:
+        if self.beam_width == 1:
+            decoded, _ = tf.nn.ctc_greedy_decoder(
+                inputs=self.logits, sequence_length=self.input_lengths,
+                merge_repeated=self.merge_repeated_outputs)
+        else:
+            decoded, _ = tf.nn.ctc_beam_search_decoder(
+                inputs=self.logits, sequence_length=self.input_lengths,
+                beam_width=self.beam_width,
+                merge_repeated=self.merge_repeated_outputs)
+
+        return tf.sparse_tensor_to_dense(
+            tf.sparse_transpose(decoded[0]),
+            default_value=self.vocabulary.get_word_index(END_TOKEN))
 
     @property
     def train_loss(self) -> tf.Tensor:
@@ -71,15 +79,15 @@ def runtime_loss(self) -> tf.Tensor:
     @tensor
     def cost(self) -> tf.Tensor:
         loss = tf.nn.ctc_loss(
-            labels=self.train_targets, inputs=self._logits,
-            sequence_length=self._input_lengths,
-            preprocess_collapse_repeated=self._merge_repeated_targets,
-            ctc_merge_repeated=self._merge_repeated_outputs)
+            labels=self.train_targets, inputs=self.logits,
+            sequence_length=self.input_lengths,
+            preprocess_collapse_repeated=self.merge_repeated_targets,
+            ctc_merge_repeated=self.merge_repeated_outputs)
 
         return tf.reduce_sum(loss)
 
     @tensor
-    def _logits(self) -> tf.Tensor:
+    def logits(self) -> tf.Tensor:
         vocabulary_size = len(self.vocabulary)
 
         encoder_states = self.encoder.hidden_states
diff --git a/neuralmonkey/decoders/multi_decoder.py b/neuralmonkey/decoders/multi_decoder.py
deleted file mode 100644
index 04d43433d..000000000
--- a/neuralmonkey/decoders/multi_decoder.py
+++ /dev/null
@@ -1,148 +0,0 @@
-import tensorflow as tf
-import numpy as np
-
-from neuralmonkey.vocabulary import PAD_TOKEN
-from neuralmonkey.logging import log
-from neuralmonkey.dataset import Dataset
-
-
-class MultiDecoder(object):
-    """The MultiDecoder class wraps a several child decoders into
-    one parent encoder. The Neural Monkey architecture requires the model to
-    have only one decoder, so this class can be used when more than one output
-    sequence should be generated (i.e. multi-task learning).
-
-    The multi decoder object composes of one main decoder and an arbitrary
-    number of additional decoders, called 'regularization' decoders.
-
-    The reason for this division is that during validation, we need to report
-    a single score of the model as a whole, and based on this score, the
-    training process decides whether to save the model variables or not.
-
-    So if the task is translation with POS tagging of the source sentence, the
-    main decoder should be the decoder that generates the target sentence,
-    whereas the sequence labeler used for POS tagging should be included in the
-    regularization decoders list.
-
-    During training, the multi decoder works in the following way: According to
-    the value of the ``_input_selector`` placeholder, the loss corresponds to
-    one of the child decoders (so in multi-task learning, the weights in each
-    batch are updated with respect only to one sub-task). It is therefore a
-    good practice to alternate between batches of different task. This is
-    because we often do not have the training data that cover all tasks in one
-    corpus.
-
-    """
-
-    def __init__(self, main_decoder, regularization_decoders):
-        """Create a new instance of the multi-decoder.
-
-        Arguments:
-            main_decoder: The decoder that corresponds to the output which
-                we want at runtime.
-
-            regularization_decoders: A list of the decoders among which the
-                multidecoder will switch.
-        """
-        self.main_decoder = main_decoder
-
-        self.regularization_decoders = regularization_decoders
-
-        self._training_decoders = [main_decoder] + regularization_decoders
-        self._decoder_costs = tf.concat([tf.expand_dims(d.cost, 0)
-                                         for d in self._training_decoders], 0)
-
-        self._scheduled_decoder = 0
-        self._input_selector = tf.placeholder(tf.float32,
-                                              [len(self._training_decoders)],
-                                              name="input_decoder_selector")
-
-        log("MultiDecoder initialized.")
-
-    def all_decoded(self):
-        return [d.decoded for d in self._training_decoders]
-
-    @property
-    def cost(self):
-        # Without specifying dimension, returns a scalar.
-        return tf.reduce_sum(self._decoder_costs * self._input_selector)
-
-    @property
-    def train_loss(self):
-        return self.cost
-
-    # The other @properties transparently point to self.main_encoder, because
-    # they are only used when we want to get the decoded outputs.
-
-    @property
-    def vocabulary_size(self):
-        return self.main_decoder.vocabulary_size
-
-    @property
-    def learning_step(self):
-        # Maybe this should come from the current training decoder?
-        # return self._training_decoders[self._scheduled_decoder].learning_step
-        return self.main_decoder.learning_step
-
-    @property
-    def runtime_loss(self):
-        return self.main_decoder.runtime_loss
-
-    @property
-    def decoded(self):
-        return self.main_decoder.decoded
-
-    @property
-    def vocabulary(self):
-        return self.main_decoder.vocabulary
-
-    @property
-    def data_id(self):
-        return self.main_decoder.data_id
-
-    def feed_dict(self, dataset, train=False):
-        """Populate the feed dictionary for the decoder object
-
-        Decoder placeholders:
-            ``input_selector``: the index of the child decoder used for
-                computing the loss
-        """
-        # Two options:
-        #  - call feed_dict only on the currently selected child decoder
-        #  - call feed_dict preventatively everywhere (with dummy data)
-
-        # First option:
-        # fd = self.decoders[self._scheduled_decoder].feed_dict(
-        #     dataset, train=train)
-        # First option does not seem to work, so the second option
-        # will have to do.
-
-        # Second option:
-        # (This is a fallback plan in canse TensorFlow requires us to fill in
-        #  the data placeholders for the nodes that are hidden behind a zero
-        #  through self.input_selector.)
-
-        # pylint: disable=invalid-name
-        # fd stands for feed_dict
-        fd = {}
-        for i, decoder in enumerate(self._training_decoders):
-            if i == self._scheduled_decoder:
-                fd_i = decoder.feed_dict(dataset, train=train)
-            else:
-                # serie is a generator of lists of words (i.e. sentences)
-                serie = [[PAD_TOKEN] for _ in range(len(dataset))]
-                dummy_dataset = Dataset("dummy", {decoder.data_id: serie}, {})
-                fd_i = decoder.feed_dict(dummy_dataset, train=train)
-            fd.update(fd_i)
-
-        # We now need to set the value of our input_selector placeholder
-        # as well.
-        input_selector_value = np.zeros(len(self._training_decoders))
-        input_selector_value[self._scheduled_decoder] = 1
-        fd[self._input_selector] = input_selector_value
-
-        # Schedule update
-        self._scheduled_decoder = ((self._scheduled_decoder + 1)
-                                   % len(self._training_decoders))
-
-        return fd
diff --git a/neuralmonkey/decoders/sequence_classifier.py b/neuralmonkey/decoders/sequence_classifier.py
index 8f900c081..9167e7705 100644
--- a/neuralmonkey/decoders/sequence_classifier.py
+++ b/neuralmonkey/decoders/sequence_classifier.py
@@ -6,9 +6,7 @@
 from neuralmonkey.vocabulary import Vocabulary
 from neuralmonkey.model.model_part import ModelPart, FeedDict
 from neuralmonkey.nn.mlp import MultilayerPerceptron
-
-
-# pylint: disable=too-many-instance-attributes
+from neuralmonkey.decorators import tensor
 
 
 class SequenceClassifier(ModelPart):
@@ -53,36 +51,62 @@ def __init__(self,
         self.dropout_keep_prob = dropout_keep_prob
         self.max_output_len = 1
 
-        with self.use_scope():
-            self.train_mode = tf.placeholder(tf.bool, name="train_mode")
-            self.learning_step = tf.get_variable(
-                "learning_step", [], trainable=False,
-                initializer=tf.constant_initializer(0))
-
-            self.gt_inputs = [tf.placeholder(
-                tf.int32, shape=[None], name="targets")]
-            mlp_input = tf.concat([enc.encoded for enc in encoders], 1)
-            mlp = MultilayerPerceptron(
-                mlp_input, layers, self.dropout_keep_prob, len(vocabulary),
-                activation_fn=self.activation_fn, train_mode=self.train_mode)
-
-            self.loss_with_gt_ins = tf.reduce_mean(
-                tf.nn.sparse_softmax_cross_entropy_with_logits(
-                    logits=mlp.logits, labels=self.gt_inputs[0]))
-            self.loss_with_decoded_ins = self.loss_with_gt_ins
-            self.cost = self.loss_with_gt_ins
-
-            self.decoded_seq = [mlp.classification]
-            self.decoded_logits = [mlp.logits]
-            self.runtime_logprobs = [tf.nn.log_softmax(mlp.logits)]
-
-            tf.summary.scalar(
-                'val_optimization_cost', self.cost,
-                collections=["summary_val"])
-            tf.summary.scalar(
-                'train_optimization_cost',
-                self.cost, collections=["summary_train"])
-        # pylint: enable=too-many-arguments
+        tf.summary.scalar(
+            'train_optimization_cost',
+            self.cost, collections=["summary_train"])
+    # pylint: enable=too-many-arguments
+
+    # pylint: disable=no-self-use
+    @tensor
+    def train_mode(self) -> tf.Tensor:
+        return tf.placeholder(tf.bool, name="train_mode")
+
+    @tensor
+    def gt_inputs(self) -> List[tf.Tensor]:
+        return [tf.placeholder(tf.int32, shape=[None], name="targets")]
+    # pylint: enable=no-self-use
+
+    @tensor
+    def _mlp(self) -> MultilayerPerceptron:
+        mlp_input = tf.concat([enc.encoded for enc in self.encoders], 1)
+        return MultilayerPerceptron(
+            mlp_input, self.layers,
+            self.dropout_keep_prob, len(self.vocabulary),
+            activation_fn=self.activation_fn, train_mode=self.train_mode)
+
+    @tensor
+    def loss_with_gt_ins(self) -> tf.Tensor:
+        # pylint: disable=no-member,unsubscriptable-object
+        return tf.reduce_mean(
+            tf.nn.sparse_softmax_cross_entropy_with_logits(
+                logits=self._mlp.logits, labels=self.gt_inputs[0]))
+        # pylint: enable=no-member,unsubscriptable-object
+
+    @property
+    def loss_with_decoded_ins(self) -> tf.Tensor:
+        return self.loss_with_gt_ins
+
+    @property
+    def cost(self) -> tf.Tensor:
+        return self.loss_with_gt_ins
+
+    @tensor
+    def decoded_seq(self) -> List[tf.Tensor]:
+        # pylint: disable=no-member
+        return [self._mlp.classification]
+        # pylint: enable=no-member
+
+    @tensor
+    def decoded_logits(self) -> List[tf.Tensor]:
+        # pylint: disable=no-member
+        return [self._mlp.logits]
+        # pylint: enable=no-member
+
+    @tensor
+    def runtime_logprobs(self) -> List[tf.Tensor]:
+        # pylint: disable=no-member
+        return [tf.nn.log_softmax(self._mlp.logits)]
+        # pylint: enable=no-member
 
     @property
     def train_loss(self):
@@ -108,7 +132,9 @@ def feed_dict(self, dataset: Dataset, train: bool = False) -> FeedDict:
         label_tensors, _ = self.vocabulary.sentences_to_tensor(
             sentences_list, self.max_output_len)
 
+        # pylint: disable=unsubscriptable-object
         fd[self.gt_inputs[0]] = label_tensors[0]
+        # pylint: enable=unsubscriptable-object
 
         fd[self.train_mode] = train
 
diff --git a/neuralmonkey/decoders/sequence_labeler.py b/neuralmonkey/decoders/sequence_labeler.py
index 7a7e62cd5..e3e0c3f29 100644
--- a/neuralmonkey/decoders/sequence_labeler.py
+++ b/neuralmonkey/decoders/sequence_labeler.py
@@ -30,13 +30,21 @@ def __init__(self,
         self.rnn_size = self.encoder.rnn_size * 2
         self.max_output_len = self.encoder.max_input_len
 
-        self.train_targets = tf.placeholder(tf.int32, shape=[None, None],
-                                            name="labeler_targets")
+    # pylint: disable=no-self-use
+    @tensor
+    def train_targets(self) -> tf.Tensor:
+        return tf.placeholder(tf.int32, shape=[None, None],
+                              name="labeler_targets")
 
-        self.train_weights = tf.placeholder(tf.float32, shape=[None, None],
-                                            name="labeler_padding_weights")
+    @tensor
+    def train_weights(self) -> tf.Tensor:
+        return tf.placeholder(tf.float32, shape=[None, None],
+                              name="labeler_padding_weights")
 
-        self.train_mode = tf.placeholder(tf.bool, name="train_mode")
+    @tensor
+    def train_mode(self) -> tf.Tensor:
+        return tf.placeholder(tf.bool, name="train_mode")
+    # pylint: enable=no-self-use
 
     @property
     def train_loss(self) -> tf.Tensor:
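
For readers unfamiliar with the `@tensor` decorator imported from `neuralmonkey.decorators`: the patch replaces eager graph construction in `__init__` with methods that are turned into lazily evaluated, cached properties, so each sub-graph is built at most once, on first access. Below is a minimal sketch of such a cache-on-first-access decorator; the cache attribute name and the example class are hypothetical, and the actual Neural Monkey implementation may do more (for instance, build the ops inside the model part's variable scope), so treat this as an illustration rather than the project's code.

# Illustrative sketch only -- not the actual neuralmonkey.decorators.tensor.
from functools import wraps


def tensor(func):
    """Cache the result of a graph-building method as a read-only property."""
    attribute = "_" + func.__name__ + "_cached"  # hypothetical cache slot name

    @property
    @wraps(func)
    def wrapper(self):
        # Build the value on first access, then reuse the stored result.
        if not hasattr(self, attribute):
            setattr(self, attribute, func(self))
        return getattr(self, attribute)

    return wrapper


class Example:
    @tensor
    def value(self):
        print("building")  # printed only once per instance
        return 42


part = Example()
assert part.value == 42
assert part.value == 42  # second access hits the cache; no rebuild

One consequence of this pattern, visible throughout the diff, is that attributes such as `logits` or `input_lengths` no longer need to be computed in a fixed order inside `__init__`; referencing `self.logits` from `decoded` or `cost` triggers its construction on demand.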