Skip to content

Commit

Permalink
Merge pull request #459 from ufal/decoders_refactor
Browse files Browse the repository at this point in the history
`@tensor` to non-recurrent decoders
  • Loading branch information
jlibovicky authored Jun 5, 2017
2 parents 436e05a + 9260fc5 commit 8326820
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 218 deletions.
72 changes: 40 additions & 32 deletions neuralmonkey/decoders/ctc_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,33 +32,41 @@ def __init__(self,
self.vocabulary = vocabulary
self.data_id = data_id

self._merge_repeated_targets = merge_repeated_targets
self._merge_repeated_outputs = merge_repeated_outputs
self._beam_width = beam_width

with self.use_scope():
self.train_targets = tf.sparse_placeholder(tf.int32,
name="targets")

self.train_mode = tf.placeholder(tf.bool, name="train_mode")

# encoder.states_mask is batch-major
self._input_lengths = tf.reduce_sum(
tf.to_int32(self.encoder.states_mask), 1)

if beam_width == 1:
decoded, _ = tf.nn.ctc_greedy_decoder(
inputs=self._logits, sequence_length=self._input_lengths,
merge_repeated=self._merge_repeated_outputs)
else:
decoded, _ = tf.nn.ctc_beam_search_decoder(
inputs=self._logits, sequence_length=self._input_lengths,
beam_width=self._beam_width,
merge_repeated=self._merge_repeated_outputs)

self.decoded = tf.sparse_tensor_to_dense(
tf.sparse_transpose(decoded[0]),
default_value=self.vocabulary.get_word_index(END_TOKEN))
self.merge_repeated_targets = merge_repeated_targets
self.merge_repeated_outputs = merge_repeated_outputs
self.beam_width = beam_width
# pylint: enable=too-many-arguments

# pylint: disable=no-self-use
@tensor
def train_targets(self) -> tf.SparseTensor:
    """Sparse placeholder for the ground-truth label sequences.

    NOTE: ``tf.sparse_placeholder`` returns a ``SparseTensor`` (fed via
    ``feed_dict`` with a ``tf.SparseTensorValue``), hence the annotation.
    """
    return tf.sparse_placeholder(tf.int32, name="targets")

@tensor
def train_mode(self) -> tf.Tensor:
    """Boolean scalar placeholder switching train/runtime behavior."""
    return tf.placeholder(tf.bool, name="train_mode")
# Close the pragma opened above train_targets; the original repeated
# "disable" here, which left no-self-use suppressed for the whole module.
# pylint: enable=no-self-use

@tensor
def input_lengths(self) -> tf.Tensor:
    """Per-example count of valid encoder states.

    ``encoder.states_mask`` is batch-major, so summing the mask along
    axis 1 (time) gives the unpadded sequence length of each example.
    """
    mask_as_int = tf.to_int32(self.encoder.states_mask)
    return tf.reduce_sum(mask_as_int, 1)

@tensor
def decoded(self) -> tf.Tensor:
    """Dense batch-major matrix of decoded label indices.

    Greedy CTC decoding is used when ``beam_width`` is 1, beam search
    otherwise.  Positions past each decoded sequence's end are filled
    with the vocabulary index of ``END_TOKEN``.
    """
    if self.beam_width == 1:
        sparse_outputs, _ = tf.nn.ctc_greedy_decoder(
            inputs=self.logits,
            sequence_length=self.input_lengths,
            merge_repeated=self.merge_repeated_outputs)
    else:
        sparse_outputs, _ = tf.nn.ctc_beam_search_decoder(
            inputs=self.logits,
            sequence_length=self.input_lengths,
            beam_width=self.beam_width,
            merge_repeated=self.merge_repeated_outputs)

    # The decoders return time-major sparse tensors; transpose to
    # batch-major before densifying.
    end_index = self.vocabulary.get_word_index(END_TOKEN)
    return tf.sparse_tensor_to_dense(
        tf.sparse_transpose(sparse_outputs[0]),
        default_value=end_index)

@property
def train_loss(self) -> tf.Tensor:
Expand All @@ -71,15 +79,15 @@ def runtime_loss(self) -> tf.Tensor:
@tensor
def cost(self) -> tf.Tensor:
    """Summed CTC loss over the batch.

    The diff rendering fused the pre-refactor argument lines
    (``self._logits`` etc.) with the post-refactor ones; this is the
    de-garbled post-refactor body, consistent with the renamed
    ``logits`` / ``input_lengths`` / ``merge_repeated_*`` attributes
    used elsewhere in this commit.
    """
    loss = tf.nn.ctc_loss(
        labels=self.train_targets, inputs=self.logits,
        sequence_length=self.input_lengths,
        preprocess_collapse_repeated=self.merge_repeated_targets,
        ctc_merge_repeated=self.merge_repeated_outputs)

    return tf.reduce_sum(loss)

@tensor
def _logits(self) -> tf.Tensor:
def logits(self) -> tf.Tensor:
vocabulary_size = len(self.vocabulary)

encoder_states = self.encoder.hidden_states
Expand Down
148 changes: 0 additions & 148 deletions neuralmonkey/decoders/multi_decoder.py

This file was deleted.

92 changes: 59 additions & 33 deletions neuralmonkey/decoders/sequence_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
from neuralmonkey.vocabulary import Vocabulary
from neuralmonkey.model.model_part import ModelPart, FeedDict
from neuralmonkey.nn.mlp import MultilayerPerceptron


# pylint: disable=too-many-instance-attributes
from neuralmonkey.decorators import tensor


class SequenceClassifier(ModelPart):
Expand Down Expand Up @@ -53,36 +51,62 @@ def __init__(self,
self.dropout_keep_prob = dropout_keep_prob
self.max_output_len = 1

with self.use_scope():
self.train_mode = tf.placeholder(tf.bool, name="train_mode")
self.learning_step = tf.get_variable(
"learning_step", [], trainable=False,
initializer=tf.constant_initializer(0))

self.gt_inputs = [tf.placeholder(
tf.int32, shape=[None], name="targets")]
mlp_input = tf.concat([enc.encoded for enc in encoders], 1)
mlp = MultilayerPerceptron(
mlp_input, layers, self.dropout_keep_prob, len(vocabulary),
activation_fn=self.activation_fn, train_mode=self.train_mode)

self.loss_with_gt_ins = tf.reduce_mean(
tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=mlp.logits, labels=self.gt_inputs[0]))
self.loss_with_decoded_ins = self.loss_with_gt_ins
self.cost = self.loss_with_gt_ins

self.decoded_seq = [mlp.classification]
self.decoded_logits = [mlp.logits]
self.runtime_logprobs = [tf.nn.log_softmax(mlp.logits)]

tf.summary.scalar(
'val_optimization_cost', self.cost,
collections=["summary_val"])
tf.summary.scalar(
'train_optimization_cost',
self.cost, collections=["summary_train"])
# pylint: enable=too-many-arguments
tf.summary.scalar(
'train_optimization_cost',
self.cost, collections=["summary_train"])
# pylint: enable=too-many-arguments

# pylint: disable=no-self-use
@tensor
def train_mode(self) -> tf.Tensor:
    """Boolean scalar placeholder switching train/runtime behavior."""
    placeholder = tf.placeholder(tf.bool, name="train_mode")
    return placeholder

@tensor
def gt_inputs(self) -> List[tf.Tensor]:
    """Single-element list holding the int32 target-class placeholder."""
    targets = tf.placeholder(tf.int32, shape=[None], name="targets")
    return [targets]
# pylint: enable=no-self-use

@tensor
def _mlp(self) -> MultilayerPerceptron:
    """MLP classifier built over the concatenated encoder outputs."""
    concatenated = tf.concat(
        [enc.encoded for enc in self.encoders], 1)
    return MultilayerPerceptron(
        concatenated, self.layers,
        self.dropout_keep_prob, len(self.vocabulary),
        activation_fn=self.activation_fn, train_mode=self.train_mode)

@tensor
def loss_with_gt_ins(self) -> tf.Tensor:
    """Mean softmax cross-entropy between MLP logits and gold classes."""
    # pylint: disable=no-member,unsubscriptable-object
    xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=self._mlp.logits, labels=self.gt_inputs[0])
    return tf.reduce_mean(xent)
    # pylint: enable=no-member,unsubscriptable-object

@property
def loss_with_decoded_ins(self) -> tf.Tensor:
    """Runtime loss; an alias for ``loss_with_gt_ins``."""
    return self.loss_with_gt_ins

@property
def cost(self) -> tf.Tensor:
    """Optimization objective; an alias for ``loss_with_gt_ins``."""
    return self.loss_with_gt_ins

@tensor
def decoded_seq(self) -> List[tf.Tensor]:
    """Predicted class indices, wrapped in a single-element list."""
    # pylint: disable=no-member
    prediction = self._mlp.classification
    # pylint: enable=no-member
    return [prediction]

@tensor
def decoded_logits(self) -> List[tf.Tensor]:
    """Raw classifier logits, wrapped in a single-element list."""
    # pylint: disable=no-member
    logits = self._mlp.logits
    # pylint: enable=no-member
    return [logits]

@tensor
def runtime_logprobs(self) -> List[tf.Tensor]:
    """Log-softmax of the classifier logits, in a single-element list."""
    # pylint: disable=no-member
    logits = self._mlp.logits
    # pylint: enable=no-member
    return [tf.nn.log_softmax(logits)]

@property
def train_loss(self):
Expand All @@ -108,7 +132,9 @@ def feed_dict(self, dataset: Dataset, train: bool = False) -> FeedDict:
label_tensors, _ = self.vocabulary.sentences_to_tensor(
sentences_list, self.max_output_len)

# pylint: disable=unsubscriptable-object
fd[self.gt_inputs[0]] = label_tensors[0]
# pylint: enable=unsubscriptable-object

fd[self.train_mode] = train

Expand Down
18 changes: 13 additions & 5 deletions neuralmonkey/decoders/sequence_labeler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,21 @@ def __init__(self,
self.rnn_size = self.encoder.rnn_size * 2
self.max_output_len = self.encoder.max_input_len

self.train_targets = tf.placeholder(tf.int32, shape=[None, None],
name="labeler_targets")
# pylint: disable=no-self-use
@tensor
def train_targets(self) -> tf.Tensor:
    """Int32 placeholder for gold labels, shape ``[None, None]``.

    NOTE(review): presumably (batch, time) — confirm against feed_dict.
    """
    return tf.placeholder(
        tf.int32, shape=[None, None], name="labeler_targets")

self.train_weights = tf.placeholder(tf.float32, shape=[None, None],
name="labeler_padding_weights")
@tensor
def train_weights(self) -> tf.Tensor:
    """Float placeholder masking padded positions, shape ``[None, None]``."""
    return tf.placeholder(
        tf.float32, shape=[None, None], name="labeler_padding_weights")

self.train_mode = tf.placeholder(tf.bool, name="train_mode")
@tensor
def train_mode(self) -> tf.Tensor:
    """Boolean scalar placeholder for the train/runtime switch."""
    return tf.placeholder(tf.bool, name="train_mode")
# pylint: enable=no-self-use

@property
def train_loss(self) -> tf.Tensor:
Expand Down

0 comments on commit 8326820

Please sign in to comment.