diff --git a/neuralmonkey/decoders/decoder.py b/neuralmonkey/decoders/decoder.py index f94b16951..fcdbc476e 100644 --- a/neuralmonkey/decoders/decoder.py +++ b/neuralmonkey/decoders/decoder.py @@ -1,7 +1,7 @@ # pylint: disable=too-many-lines import math -from typing import (cast, Iterable, List, Callable, Optional, - Any, Tuple, NamedTuple, Union) +from typing import (cast, Iterable, List, Callable, Optional, Any, Tuple, + NamedTuple) import numpy as np import tensorflow as tf @@ -13,19 +13,20 @@ PAD_TOKEN_INDEX) from neuralmonkey.model.model_part import ModelPart, FeedDict from neuralmonkey.model.sequence import EmbeddedSequence -from neuralmonkey.model.stateful import (TemporalStatefulWithOutput, - SpatialStatefulWithOutput) +from neuralmonkey.model.stateful import Stateful from neuralmonkey.logging import log, warn -from neuralmonkey.nn.ortho_gru_cell import OrthoGRUCell +from neuralmonkey.nn.ortho_gru_cell import OrthoGRUCell, NematusGRUCell from neuralmonkey.nn.utils import dropout from neuralmonkey.decoders.encoder_projection import ( - linear_encoder_projection, concat_encoder_projection, empty_initial_state) + linear_encoder_projection, concat_encoder_projection, empty_initial_state, + EncoderProjection) from neuralmonkey.decoders.output_projection import (OutputProjectionSpec, nonlinear_output) from neuralmonkey.decorators import tensor RNN_CELL_TYPES = { + "NematusGRU": NematusGRUCell, "GRU": OrthoGRUCell, "LSTM": tf.contrib.rnn.LSTMCell } @@ -62,9 +63,7 @@ class Decoder(ModelPart): # pylint: disable=too-many-locals # pylint: disable=too-many-arguments,too-many-branches,too-many-statements def __init__(self, - # TODO only stateful, attention will need temporal or spat. 
- encoders: List[Union[TemporalStatefulWithOutput, - SpatialStatefulWithOutput]], + encoders: List[Stateful], vocabulary: Vocabulary, data_id: str, name: str, @@ -73,9 +72,7 @@ def __init__(self, rnn_size: int = None, embedding_size: int = None, output_projection: OutputProjectionSpec = None, - encoder_projection: Callable[ - [tf.Tensor, Optional[int], Optional[List[Any]]], - tf.Tensor]=None, + encoder_projection: EncoderProjection = None, attentions: List[BaseAttention] = None, embeddings_source: EmbeddedSequence = None, attention_on_input: bool = True, @@ -164,7 +161,8 @@ def __init__(self, assert self.rnn_size is not None if self._rnn_cell_str not in RNN_CELL_TYPES: - raise ValueError("RNN cell must be a either 'GRU' or 'LSTM'") + raise ValueError("RNN cell must be a either 'GRU', 'LSTM', or " + "'NematusGRU'. Not {}".format(self._rnn_cell_str)) if self.output_projection_spec is None: log("No output projection specified - using tanh projection") @@ -365,7 +363,11 @@ def _get_rnn_cell(self) -> tf.contrib.rnn.RNNCell: return RNN_CELL_TYPES[self._rnn_cell_str](self.rnn_size) def _get_conditional_gru_cell(self) -> tf.contrib.rnn.GRUCell: - return tf.contrib.rnn.GRUCell(self.rnn_size) + if self._rnn_cell_str == "NematusGRU": + return NematusGRUCell( + self.rnn_size, use_state_bias=True, use_input_bias=False) + + return RNN_CELL_TYPES[self._rnn_cell_str](self.rnn_size) def embed_input_symbol(self, *args) -> tf.Tensor: loop_state = LoopState(*args) @@ -403,16 +405,17 @@ def body(*args) -> LoopState: # Run the RNN. 
cell = self._get_rnn_cell() - if self._rnn_cell_str == "GRU": - cell_output, state = cell(rnn_input, - loop_state.prev_rnn_output) - next_state = state + if self._rnn_cell_str in ["GRU", "NematusGRU"]: + cell_output, next_state = cell( + rnn_input, loop_state.prev_rnn_output) + attns = [ a.attention(cell_output, loop_state.prev_rnn_output, rnn_input, att_loop_state, loop_state.step) for a, att_loop_state in zip( self.attentions, loop_state.attention_loop_states)] + if self.attentions: contexts, att_loop_states = zip(*attns) else: @@ -421,8 +424,9 @@ def body(*args) -> LoopState: if self._conditional_gru: cell_cond = self._get_conditional_gru_cell() cond_input = tf.concat(contexts, -1) - cell_output, state = cell_cond(cond_input, state, - scope="cond_gru_2_cell") + cell_output, next_state = cell_cond( + cond_input, next_state, scope="cond_gru_2_cell") + elif self._rnn_cell_str == "LSTM": prev_state = tf.contrib.rnn.LSTMStateTuple( loop_state.prev_rnn_state, loop_state.prev_rnn_output) diff --git a/neuralmonkey/decoders/encoder_projection.py b/neuralmonkey/decoders/encoder_projection.py index 817da2192..36b55092e 100644 --- a/neuralmonkey/decoders/encoder_projection.py +++ b/neuralmonkey/decoders/encoder_projection.py @@ -2,40 +2,47 @@ This module contains different variants of projection of encoders into the initial state of the decoder. -""" +Encoder projections are specified in the configuration file. Each encoder +projection function has a unified type ``EncoderProjection``, which is a +callable that takes three arguments: + +1. ``train_mode`` -- boolean tensor specifying whether the train mode is on +2. ``rnn_size`` -- the size of the resulting initial state +3. ``encoders`` -- a list of ``Stateful`` objects used as the encoders. -from typing import List, Optional, Callable +To enable further parameterization of encoder projection functions, one can +use higher-order functions. 
+""" +from typing import List, Callable, cast import tensorflow as tf +from typeguard import check_argument_types -from neuralmonkey.model.stateful import Stateful +from neuralmonkey.model.stateful import Stateful, TemporalStatefulWithOutput from neuralmonkey.nn.utils import dropout from neuralmonkey.logging import log +# pylint: disable=invalid-name +EncoderProjection = Callable[ + [tf.Tensor, int, List[Stateful]], tf.Tensor] +# pylint: enable=invalid-name + + # pylint: disable=unused-argument # The function must conform the API def empty_initial_state(train_mode: tf.Tensor, - rnn_size: Optional[int], + rnn_size: int, encoders: List[Stateful] = None) -> tf.Tensor: - """Return an empty vector. - - Arguments: - train_mode: tf 0-D bool Tensor specifying the training mode (not used) - rnn_size: The size of the resulting vector - encoders: The list of encoders (not used) - """ + """Return an empty vector.""" if rnn_size is None: - raise ValueError("You must supply rnn_size for this type of " - "encoder projection") + raise ValueError( + "You must supply rnn_size for this type of encoder projection") return tf.zeros([rnn_size]) -def linear_encoder_projection( - dropout_keep_prob: float) -> Callable[ - [tf.Tensor, Optional[int], Optional[List[Stateful]]], - tf.Tensor]: +def linear_encoder_projection(dropout_keep_prob: float) -> EncoderProjection: """Return a linear encoder projection. Return a projection function which applies dropout on concatenated @@ -45,61 +52,82 @@ def linear_encoder_projection( Arguments: dropout_keep_prob: The dropout keep probability """ + check_argument_types() + def func(train_mode: tf.Tensor, - rnn_size: Optional[int] = None, - encoders: Optional[List[Stateful]] = None) -> tf.Tensor: - """Linearly project encoders' encoded value. - - Linearly project encoders' encoded value to rnn_size - and apply dropout. 
- - Arguments: - train_mode: tf 0-D bool Tensor specifying the training mode - rnn_size: The size of the resulting vector - encoders: The list of encoders - """ - if rnn_size is None: - raise ValueError("You must supply rnn_size for this type of " - "encoder projection") + rnn_size: int, + encoders: List[Stateful]) -> tf.Tensor: - if encoders is None or not encoders: - raise ValueError("There must be at least one encoder for this type" - " of encoder projection") + if rnn_size is None: + raise ValueError( + "You must supply rnn_size for this type of encoder projection") - encoded_concat = tf.concat([e.output for e in encoders], 1) - encoded_concat = dropout( - encoded_concat, dropout_keep_prob, train_mode) + en_concat = concat_encoder_projection(train_mode, None, encoders) + en_concat = dropout(en_concat, dropout_keep_prob, train_mode) - return tf.layers.dense(encoded_concat, rnn_size, - name="encoders_projection") + return tf.layers.dense(en_concat, rnn_size, name="encoders_projection") - return func + return cast(EncoderProjection, func) def concat_encoder_projection( train_mode: tf.Tensor, - rnn_size: Optional[int] = None, - encoders: Optional[List[Stateful]] = None) -> tf.Tensor: - """Create the initial state by concatenating the encoders' encoded values. 
+ rnn_size: int = None, + encoders: List[Stateful] = None) -> tf.Tensor: + """Concatenate the encoded values of the encoders.""" - Arguments: - train_mode: tf 0-D bool Tensor specifying the training mode (not used) - rnn_size: The size of the resulting vector (not used) - encoders: The list of encoders - """ if encoders is None or not encoders: raise ValueError("There must be at least one encoder for this type " "of encoder projection") - if rnn_size is not None: - assert rnn_size == sum(e.output.get_shape()[1].value - for e in encoders) - - encoded_concat = tf.concat([e.output for e in encoders], 1) + output_size = sum(e.output.get_shape()[1].value for e in encoders) + if rnn_size is not None and rnn_size != output_size: + raise ValueError("RNN size supplied for concat projection ({}) does " + "not match the size of the concatenated vectors ({})." + .format(rnn_size, output_size)) - # pylint: disable=no-member log("The inferred rnn_size of this encoder projection will be {}" - .format(encoded_concat.get_shape()[1].value)) - # pylint: enable=no-member + .format(output_size)) + encoded_concat = tf.concat([e.output for e in encoders], 1) return encoded_concat + + +def nematus_projection(dropout_keep_prob: float = 1.0) -> EncoderProjection: + """Return encoder projection used in Nematus. + + The initial state is a dense projection with tanh activation computed on + the averaged states of the encoders. Dropout is applied to the means + (before the projection). + + Arguments: + dropout_keep_prob: The dropout keep probability. + """ + check_argument_types() + + def func( + train_mode: tf.Tensor, + rnn_size: int, + encoders: List[TemporalStatefulWithOutput]) -> tf.Tensor: + + if len(encoders) != 1: + raise ValueError("Exactly one encoder required for this type of " + "projection. 
{} given.".format(len(encoders))) + encoder = encoders[0] + + # shape (batch, state_size) + masked_sum = tf.reduce_sum( + encoder.temporal_states + * tf.expand_dims(encoder.temporal_mask, 2), 1) + + # shape (batch, 1) + lengths = tf.reduce_sum(encoder.temporal_mask, 1, keep_dims=True) + + means = masked_sum / lengths + means = dropout(means, dropout_keep_prob, train_mode) + + return tf.layers.dense(means, rnn_size, + activation=tf.tanh, + name="encoders_projection") + + return cast(EncoderProjection, func) diff --git a/neuralmonkey/decoders/output_projection.py b/neuralmonkey/decoders/output_projection.py index 49cfe57cc..6e787ea35 100644 --- a/neuralmonkey/decoders/output_projection.py +++ b/neuralmonkey/decoders/output_projection.py @@ -1,10 +1,25 @@ -"""Module with different variants of projection functions for RNN outputs.""" +"""Output Projection Module. +This module contains different variants of projection functions of decoder +outputs into the logit function inputs. + +Output projections are specified in the configuration file. Each output +projection function has a unified type ``OutputProjection``, which is a +callable that takes four arguments and returns a tensor: + +1. ``prev_state`` -- the hidden state of the decoder. +2. ``prev_output`` -- embedding of the previously decoded word (or train input) +3. ``ctx_tensors`` -- a list of context vectors (for each attention object) + +To enable further parameterization of output projection functions, one can +use higher-order functions. 
+""" from typing import Union, Tuple, List, Callable import tensorflow as tf from typeguard import check_argument_types from neuralmonkey.nn.projection import multilayer_projection, maxout +from neuralmonkey.nn.utils import dropout # pylint: disable=invalid-name @@ -57,6 +72,35 @@ def _projection(prev_state, prev_output, ctx_tensors, train_mode): return _projection, output_size +def nematus_output( + output_size: int, + dropout_keep_prob: float = 1.0) -> Tuple[OutputProjection, int]: + """Apply nonlinear one-hidden-layer deep output. + + Implementation consistent with Nematus. + Can be used instead of (and is in theory equivalent to) nonlinear_output. + + Projects the RNN state, embedding of the previously outputted word, and + concatenation of all context vectors into a shared vector space, sums them + up and apply a hyperbolic tangent activation function. + """ + check_argument_types() + + def _projection(prev_state, prev_output, ctx_tensors, train_mode): + prev_state = dropout(prev_state, dropout_keep_prob, train_mode) + prev_output = dropout(prev_output, dropout_keep_prob, train_mode) + ctx_concat = tf.concat(ctx_tensors, 1) + ctx = dropout(ctx_concat, dropout_keep_prob, train_mode) + + logit_rnn = tf.layers.dense(prev_state, output_size, name="rnn_state") + logit_emb = tf.layers.dense(prev_output, output_size, name="prev_out") + logit_ctx = tf.layers.dense(ctx, output_size, name="context") + + return tf.tanh(logit_rnn + logit_emb + logit_ctx) + + return _projection, output_size + + def nonlinear_output( output_size: int, activation_fn: Callable[[tf.Tensor], tf.Tensor] = tf.tanh diff --git a/neuralmonkey/encoders/recurrent.py b/neuralmonkey/encoders/recurrent.py index 4978b3e8b..31890fbdd 100644 --- a/neuralmonkey/encoders/recurrent.py +++ b/neuralmonkey/encoders/recurrent.py @@ -5,7 +5,7 @@ from neuralmonkey.model.stateful import TemporalStatefulWithOutput from neuralmonkey.model.model_part import ModelPart, FeedDict -from neuralmonkey.nn.ortho_gru_cell import 
OrthoGRUCell +from neuralmonkey.nn.ortho_gru_cell import OrthoGRUCell, NematusGRUCell from neuralmonkey.nn.utils import dropout from neuralmonkey.vocabulary import Vocabulary from neuralmonkey.dataset import Dataset @@ -18,6 +18,7 @@ # pylint: enable=invalid-name RNN_CELL_TYPES = { + "NematusGRU": NematusGRUCell, "GRU": OrthoGRUCell, "LSTM": tf.contrib.rnn.LSTMCell } @@ -52,7 +53,8 @@ def __init__(self, raise ValueError("Dropout keep prob must be inside (0,1].") if self.rnn_cell_str not in RNN_CELL_TYPES: - raise ValueError("RNN cell must be a either 'GRU' or 'LSTM'") + raise ValueError("RNN cell must be a either 'GRU', 'LSTM', or " + "'NematusGRU'. Not {}".format(self.rnn_cell_str)) if output_size is not None: if output_size <= 0: @@ -63,7 +65,6 @@ def __init__(self, else: self._project_final_state = False self.output_size = 2 * rnn_size - # pylint: enable=too-many-arguments # pylint: disable=no-self-use @@ -97,7 +98,7 @@ def temporal_states(self) -> tf.Tensor: @tensor def output(self) -> tf.Tensor: # pylint: disable=unsubscriptable-object - if self.rnn_cell_str == "GRU": + if self.rnn_cell_str in ["GRU", "NematusGRU"]: output = tf.concat(self.bidirectional_rnn[1], 1) elif self.rnn_cell_str == "LSTM": # TODO is "h" what we want? diff --git a/neuralmonkey/model/sequence.py b/neuralmonkey/model/sequence.py index 220f9c59b..469a17664 100644 --- a/neuralmonkey/model/sequence.py +++ b/neuralmonkey/model/sequence.py @@ -86,12 +86,15 @@ def lengths(self) -> tf.Tensor: class EmbeddedFactorSequence(Sequence): """A `Sequence` that stores one or more embedded inputs (factors).""" + # pylint: disable=too-many-arguments def __init__(self, name: str, vocabularies: List[Vocabulary], data_ids: List[str], embedding_sizes: List[int], max_length: int = None, + add_start_symbol: bool = False, + add_end_symbol: bool = False, save_checkpoint: str = None, load_checkpoint: str = None) -> None: """Construct a new instance of `EmbeddedFactorSequence`. 
@@ -109,6 +112,8 @@ def __init__(self, embedding_sizes: A list of integers specifying the size of the embedding vector for each factor max_length: The maximum length of the sequences + add_start_symbol: Includes the start symbol in the sequence + add_end_symbol: Includes the end symbol in the sequence save_checkpoint: The save_checkpoint parameter for `ModelPart` load_checkpoint: The load_checkpoint parameter for `ModelPart` """ @@ -120,6 +125,8 @@ def __init__(self, self.vocabulary_sizes = [len(vocab) for vocab in self.vocabularies] self.data_ids = data_ids self.embedding_sizes = embedding_sizes + self.add_start_symbol = add_start_symbol + self.add_end_symbol = add_end_symbol if not (len(self.data_ids) == len(self.vocabularies) @@ -129,6 +136,7 @@ def __init__(self, if any([esize <= 0 for esize in self.embedding_sizes]): raise ValueError("Embedding size must be a positive integer.") + # pylint: enable=too-many-arguments # TODO this should be placed into the abstract embedding class def tb_embedding_visualization(self, logdir: str, @@ -234,7 +242,8 @@ def feed_dict(self, dataset: Dataset, train: bool = False) -> FeedDict: factors = dataset.get_series(name) vectors, paddings = vocabulary.sentences_to_tensor( list(factors), self.max_length, pad_to_max_len=False, - train_mode=train) + train_mode=train, add_start_symbol=self.add_start_symbol, + add_end_symbol=self.add_end_symbol) fd[factor_plc] = list(zip(*vectors)) @@ -252,12 +261,15 @@ def feed_dict(self, dataset: Dataset, train: bool = False) -> FeedDict: class EmbeddedSequence(EmbeddedFactorSequence): """A sequence of embedded inputs (for a single factor).""" + # pylint: disable=too-many-arguments def __init__(self, name: str, vocabulary: Vocabulary, data_id: str, embedding_size: int, max_length: int = None, + add_start_symbol: bool = False, + add_end_symbol: bool = False, save_checkpoint: str = None, load_checkpoint: str = None) -> None: """Construct a new instance of `EmbeddedSequence`. 
@@ -270,6 +282,8 @@ def __init__(self, embedding_sizes: An integer that specifies the size of the embedding vector for the sequence data max_length: The maximum length of the sequences + add_start_symbol: Includes the start symbol in the sequence + add_end_symbol: Includes the end symbol in the sequence save_checkpoint: The save_checkpoint parameter for `ModelPart` load_checkpoint: The load_checkpoint parameter for `ModelPart` """ @@ -280,8 +294,11 @@ def __init__(self, data_ids=[data_id], embedding_sizes=[embedding_size], max_length=max_length, + add_start_symbol=add_start_symbol, + add_end_symbol=add_end_symbol, save_checkpoint=save_checkpoint, load_checkpoint=load_checkpoint) + # pylint: enable=too-many-arguments # pylint: disable=unsubscriptable-object @property diff --git a/neuralmonkey/nn/ortho_gru_cell.py b/neuralmonkey/nn/ortho_gru_cell.py index 3d0e5c53b..4e74f8823 100644 --- a/neuralmonkey/nn/ortho_gru_cell.py +++ b/neuralmonkey/nn/ortho_gru_cell.py @@ -5,7 +5,65 @@ class OrthoGRUCell(tf.contrib.rnn.GRUCell): """Classic GRU cell but initialized using random orthogonal matrices.""" - def __call__(self, inputs, state, scope=None): - with tf.variable_scope(scope or "OrthoGRUCell") as vscope: - vscope.set_initializer(tf.orthogonal_initializer()) - return super().__call__(inputs, state, vscope) + def __init__(self, + num_units, + activation=None, + reuse=None, + bias_initializer=None): + tf.contrib.rnn.GRUCell.__init__( + self, num_units, activation, reuse, tf.orthogonal_initializer(), + bias_initializer) + + def __call__(self, inputs, state, scope="OrthoGRUCell"): + return tf.contrib.rnn.GRUCell.__call__(self, inputs, state, scope) + + +# Note that tensorflow does not like when the type annotations are present. +class NematusGRUCell(tf.contrib.rnn.GRUCell): + """Nematus implementation of gated recurrent unit cell. + + The main difference is the order in which the gating functions and linear + projections are applied to the hidden state. 
+ + The math is equivalent, in practice there are differences due to float + precision errors. + """ + + def __init__(self, rnn_size, use_state_bias=False, use_input_bias=True): + self.use_state_bias = use_state_bias + self.use_input_bias = use_input_bias + + tf.contrib.rnn.GRUCell.__init__(self, rnn_size) + + def call(self, inputs, state): + """Gated recurrent unit (GRU) with nunits cells.""" + with tf.variable_scope("gates"): + input_to_gates = tf.layers.dense( + inputs, 2 * self._num_units, name="input_proj", + use_bias=self.use_input_bias) + + # Nematus does the orthogonal initialization probably differently + state_to_gates = tf.layers.dense( + state, 2 * self._num_units, + use_bias=self.use_state_bias, + kernel_initializer=tf.orthogonal_initializer(), + name="state_proj") + + gates_input = state_to_gates + input_to_gates + reset, update = tf.split( + tf.sigmoid(gates_input), num_or_size_splits=2, axis=1) + + with tf.variable_scope("candidate"): + input_to_candidate = tf.layers.dense( + inputs, self._num_units, use_bias=self.use_input_bias, + name="input_proj") + + state_to_candidate = tf.layers.dense( + state, self._num_units, use_bias=self.use_state_bias, + name="state_proj") + + candidate = self._activation( + state_to_candidate * reset + input_to_candidate) + + new_state = update * state + (1 - update) * candidate + return new_state, new_state diff --git a/neuralmonkey/vocabulary.py b/neuralmonkey/vocabulary.py index de36afa9f..b05d77cbc 100644 --- a/neuralmonkey/vocabulary.py +++ b/neuralmonkey/vocabulary.py @@ -6,6 +6,7 @@ # pylint: disable=too-many-lines import collections +import json import os import random @@ -110,6 +111,34 @@ def from_wordlist(path: str, return vocabulary +def from_nematus_json(path: str, max_size: int = None, + pad_to_max_size: bool = False) -> "Vocabulary": + """Load vocabulary from Nematus JSON format. + + The JSON format is a flat dictionary that maps words to their index in the + vocabulary. + + Args: + path: Path to the file. 
+ max_size: Maximum vocabulary size including 'unk' and 'eos' symbols, + but not including the padding and start symbols. + """ + with open(path, "r", encoding="utf-8") as f_json: + contents = json.load(f_json) + + vocabulary = Vocabulary() + for word in sorted(contents.keys(), key=lambda x: contents[x])[2:max_size]: + vocabulary.add_word(word) + + if pad_to_max_size: + current_length = len(vocabulary) + for i in range(max_size - current_length + 2): # the "2" is ugly HACK + word = "".format(i) + vocabulary.add_word(word) + + return vocabulary + + # pylint: disable=too-many-arguments # helper function, this number of parameters is needed def from_dataset(datasets: List[Dataset], series_ids: List[str], max_size: int, diff --git a/scripts/import_nematus.py b/scripts/import_nematus.py new file mode 100755 index 000000000..7efcc371c --- /dev/null +++ b/scripts/import_nematus.py @@ -0,0 +1,485 @@ +#!/usr/bin/env python3 +"""Imports nematus model file and convert it into a neural monkey experiment +given a neural monkey configuration file. +""" +from typing import Dict, Tuple, List +import argparse +import json +import os +import numpy as np +import tensorflow as tf + +from neuralmonkey.config.parsing import parse_file +from neuralmonkey.config.builder import build_config +from neuralmonkey.attention.feed_forward import Attention +from neuralmonkey.encoders.recurrent import RecurrentEncoder +from neuralmonkey.decoders.decoder import Decoder +from neuralmonkey.decoders.encoder_projection import nematus_projection +from neuralmonkey.decoders.output_projection import nematus_output +from neuralmonkey.model.sequence import EmbeddedSequence +from neuralmonkey.vocabulary import from_nematus_json +from neuralmonkey.logging import log as _log + + +def log(message: str, color: str = "blue") -> None: + _log(message, color) + + +def check_shape(var1_tf: tf.Variable, var2_np: np.ndarray): + if var1_tf.get_shape().as_list() != list(var2_np.shape): + log("Shapes do not match! 
Exception will follow.", color="red") + + +# Here come a few functions that fiddle with the Nematus parameters in order to +# fit them to Neural Monkey parameter shapes. +def emb_fix_dim1(variables: List[np.ndarray]) -> np.ndarray: + return emb_fix(variables, dim=1) + + +def emb_fix(variables: List[np.ndarray], dim: int = 0) -> np.ndarray: + """Process nematus tensors with vocabulary dimension. + + Nematus uses only two special symbols, eos and UNK. For embeddings of start + and pad tokens, we use zero vectors, inserted to the correct position in + the parameter matrix. + + Arguments: + variables: the list of variables. Must be of length 1. + dim: The vocabulary dimension. + """ + if len(variables) != 1: + raise ValueError("VocabFix only works with single vars. {} given." + .format(len(variables))) + + if dim != 0 and dim != 1: + raise ValueError("dim can only be 0 or 1. is: {}".format(dim)) + + variable = variables[0] + shape = variable.shape + + # insert start token (hack from nematus - last from vocab - does it work? NO) + # to_insert = np.squeeze(variable[-1] if dim == 0 else variable[:, -1]) + to_insert = np.zeros(shape[1 - dim]) if len(shape) > 1 else 0. + variable = np.insert(variable, 0, to_insert, axis=dim) + + # insert padding token + to_insert = np.zeros(shape[1 - dim]) if len(shape) > 1 else 0. + variable = np.insert(variable, 0, to_insert, axis=dim) + + return variable + + +def sum_vars(variables: List[np.ndarray]) -> np.ndarray: + return sum(variables) + + +def concat_vars(variables: List[np.ndarray]) -> np.ndarray: + return np.concatenate(variables) + + +def squeeze(variables: List[np.ndarray]) -> np.ndarray: + if len(variables) != 1: + raise ValueError("Squeeze only works with single vars. {} given." 
+ .format(len(variables))) + return np.squeeze(variables[0]) + +# pylint: disable=line-too-long +# No point in line wrapping +VARIABLE_MAP = { + "encoder_input/embedding_matrix_0": (["Wemb"], emb_fix), + "decoder/word_embeddings": (["Wemb_dec"], emb_fix), + "decoder/state_to_word_W": (["ff_logit_W"], emb_fix_dim1), + "decoder/state_to_word_b": (["ff_logit_b"], emb_fix), + "encoder/bidirectional_rnn/fw/OrthoGRUCell/gates/state_proj/kernel": (["encoder_U"], None), + "encoder/bidirectional_rnn/fw/OrthoGRUCell/gates/input_proj/kernel": (["encoder_W"], None), + "encoder/bidirectional_rnn/fw/OrthoGRUCell/gates/input_proj/bias": (["encoder_b"], None), + "encoder/bidirectional_rnn/fw/OrthoGRUCell/candidate/state_proj/kernel": (["encoder_Ux"], None), + "encoder/bidirectional_rnn/fw/OrthoGRUCell/candidate/input_proj/kernel": (["encoder_Wx"], None), + "encoder/bidirectional_rnn/fw/OrthoGRUCell/candidate/input_proj/bias": (["encoder_bx"], None), + "encoder/bidirectional_rnn/bw/OrthoGRUCell/gates/state_proj/kernel": (["encoder_r_U"], None), + "encoder/bidirectional_rnn/bw/OrthoGRUCell/gates/input_proj/kernel": (["encoder_r_W"], None), + "encoder/bidirectional_rnn/bw/OrthoGRUCell/gates/input_proj/bias": (["encoder_r_b"], None), + "encoder/bidirectional_rnn/bw/OrthoGRUCell/candidate/state_proj/kernel": (["encoder_r_Ux"], None), + "encoder/bidirectional_rnn/bw/OrthoGRUCell/candidate/input_proj/kernel": (["encoder_r_Wx"], None), + "encoder/bidirectional_rnn/bw/OrthoGRUCell/candidate/input_proj/bias": (["encoder_r_bx"], None), + "decoder/initial_state/encoders_projection/kernel": (["ff_state_W"], None), + "decoder/initial_state/encoders_projection/bias": (["ff_state_b"], None), + "decoder/attention_decoder/OrthoGRUCell/gates/state_proj/kernel": (["decoder_U"], None), + "decoder/attention_decoder/OrthoGRUCell/gates/input_proj/kernel": (["decoder_W"], None), + "decoder/attention_decoder/OrthoGRUCell/gates/input_proj/bias": (["decoder_b"], None), + 
"decoder/attention_decoder/OrthoGRUCell/candidate/state_proj/kernel": (["decoder_Ux"], None), + "decoder/attention_decoder/OrthoGRUCell/candidate/input_proj/kernel": (["decoder_Wx"], None), + "decoder/attention_decoder/OrthoGRUCell/candidate/input_proj/bias": (["decoder_bx"], None), + "attention/attn_key_projection": (["decoder_Wc_att"], None), + "attention/attn_projection_bias": (["decoder_b_att"], None), + "attention/Attention/attn_query_projection": (["decoder_W_comb_att"], None), + "attention/attn_similarity_v": (["decoder_U_att"], squeeze), + "attention/attn_bias": (["decoder_c_tt"], squeeze), + "decoder/attention_decoder/cond_gru_2_cell/gates/state_proj/kernel": (["decoder_U_nl"], None), + "decoder/attention_decoder/cond_gru_2_cell/gates/input_proj/kernel": (["decoder_Wc"], None), + "decoder/attention_decoder/cond_gru_2_cell/gates/state_proj/bias": (["decoder_b_nl"], None), + "decoder/attention_decoder/cond_gru_2_cell/candidate/state_proj/kernel": (["decoder_Ux_nl"], None), + "decoder/attention_decoder/cond_gru_2_cell/candidate/input_proj/kernel": (["decoder_Wcx"], None), + "decoder/attention_decoder/cond_gru_2_cell/candidate/state_proj/bias": (["decoder_bx_nl"], None), + "decoder/attention_decoder/rnn_state/kernel": (["ff_logit_lstm_W"], None), + "decoder/attention_decoder/rnn_state/bias": (["ff_logit_lstm_b"], None), + "decoder/attention_decoder/prev_out/kernel": (["ff_logit_prev_W"], None), + "decoder/attention_decoder/prev_out/bias": (["ff_logit_prev_b"], None), + "decoder/attention_decoder/context/kernel": (["ff_logit_ctx_W"], None), + "decoder/attention_decoder/context/bias": (["ff_logit_ctx_b"], None) +} +# pylint: enable=line-too-long + +ENCODER_NAME = "encoder" +DECODER_NAME = "decoder" +ATTENTION_NAME = "attention" + + +def load_nematus_json(path: str) -> Dict: + with open(path, "r", encoding="utf-8") as f_json: + contents = json.load(f_json) + + prefix = os.path.realpath(os.path.dirname(path)) + config = { + "encoder_type": contents["encoder"], + 
"decoder_type": contents["decoder"], + "n_words_src": contents["n_words_src"], + "n_words_tgt": contents["n_words"], + "variables_file": contents["saveto"], + "rnn_size": contents["dim"], + "embedding_size": contents["dim_word"], + "src_vocabulary": os.path.join( + prefix, contents["dictionaries"][0]), + "tgt_vocabulary": os.path.join( + prefix, contents["dictionaries"][1]), + "max_length": contents["maxlen"] + } + + if config["encoder_type"] != "gru": + raise ValueError("Unsupported encoder type: {}" + .format(config["encoder_type"])) + + if config["decoder_type"] != "gru_cond": + raise ValueError("Unsupported decoder type: {}" + .format(config["decoder_type"])) + + if not os.path.isfile(config["src_vocabulary"]): + raise FileNotFoundError("Vocabulary file not found: {}" + .format(config["src_vocabulary"])) + + if not os.path.isfile(config["tgt_vocabulary"]): + raise FileNotFoundError("Vocabulary file not found: {}" + .format(config["tgt_vocabulary"])) + + return config + + +VOCABULARY_TEMPLATE = """\ +[vocabulary_{}] +class=vocabulary.from_nematus_json +path="{}" +max_size={} +pad_to_max_size=True +""" + +ENCODER_TEMPLATE = """\ +[encoder] +class=encoders.RecurrentEncoder +name="{}" +input_sequence= +rnn_size={} +rnn_cell="NematusGRU" +dropout_keep_prob=1.0 + +[input_sequence] +class=model.sequence.EmbeddedSequence +name="{}" +vocabulary= +data_id="source" +embedding_size={} +max_length={} +add_end_symbol=True +""" + + +def build_encoder(config: Dict) -> Tuple[RecurrentEncoder, str]: + vocabulary = from_nematus_json( + config["src_vocabulary"], max_size=config["n_words_src"], + pad_to_max_size=True) + + vocabulary_ini = VOCABULARY_TEMPLATE.format( + "src", config["src_vocabulary"], config["n_words_src"]) + + inp_seq_name = "{}_input".format(ENCODER_NAME) + inp_seq = EmbeddedSequence( + name=inp_seq_name, + vocabulary=vocabulary, + data_id="source", + embedding_size=config["embedding_size"]) + + encoder = RecurrentEncoder( + name=ENCODER_NAME, + 
input_sequence=inp_seq, + rnn_size=config["rnn_size"], + rnn_cell="NematusGRU") + + encoder_ini = ENCODER_TEMPLATE.format( + ENCODER_NAME, config["rnn_size"], + inp_seq_name, config["embedding_size"], config["max_length"]) + + return encoder, "\n".join([vocabulary_ini, encoder_ini]) + + +ATTENTION_TEMPLATE = """\ +[attention] +class=attention.Attention +name="{}" +encoder= +dropout_keep_prob=1.0 +""" + + +def build_attention(config: Dict, + encoder: RecurrentEncoder) -> Tuple[Attention, str]: + attention = Attention( + name=ATTENTION_NAME, + encoder=encoder) + + attention_ini = ATTENTION_TEMPLATE.format(ATTENTION_NAME) + + return attention, attention_ini + + +DECODER_TEMPLATE = """\ +[decoder] +class=decoders.Decoder +name="{}" +vocabulary= +data_id="target" +embedding_size={} +rnn_size={} +max_output_len={} +encoders=[] +encoder_projection= +attentions=[] +attention_on_input=False +conditional_gru=True +output_projection= +rnn_cell="NematusGRU" +dropout_keep_prob=1.0 + +[nematus_nonlinear] +class=decoders.output_projection.nematus_output +output_size={} +dropout_keep_prob=1.0 + +[nematus_mean] +class=decoders.encoder_projection.nematus_projection +dropout_keep_prob=1.0 +""" + + +def build_decoder(config: Dict, + attention: Attention, + encoder: RecurrentEncoder) -> Tuple[Decoder, str]: + vocabulary = from_nematus_json( + config["tgt_vocabulary"], + max_size=config["n_words_tgt"], + pad_to_max_size=True) + + vocabulary_ini = VOCABULARY_TEMPLATE.format( + "tgt", config["tgt_vocabulary"], config["n_words_tgt"]) + + decoder = Decoder( + name=DECODER_NAME, + vocabulary=vocabulary, + data_id="target", + max_output_len=config["max_length"], + embedding_size=config["embedding_size"], + rnn_size=config["rnn_size"], + encoders=[encoder], + attentions=[attention], + attention_on_input=False, + conditional_gru=True, + encoder_projection=nematus_projection(dropout_keep_prob=1.0), + output_projection=nematus_output(config["embedding_size"]), + rnn_cell="NematusGRU") + + 
decoder_ini = DECODER_TEMPLATE.format( + DECODER_NAME, config["embedding_size"], config["rnn_size"], + config["max_length"], config["embedding_size"]) + + return decoder, "\n".join([vocabulary_ini, decoder_ini]) + + +def build_model(config: Dict) -> Tuple[ + RecurrentEncoder, Attention, Decoder, str]: + encoder, encoder_cfg = build_encoder(config) + attention, attention_cfg = build_attention(config, encoder) + decoder, decoder_cfg = build_decoder(config, attention, encoder) + + ini = "\n".join([encoder_cfg, attention_cfg, decoder_cfg]) + + return ini + + +def load_nematus_file(path: str) -> Dict[str, np.ndarray]: + contents = np.load(path) + cnt_dict = dict(contents) + contents.close() + return cnt_dict + + +def assign_vars(variables: Dict[str, np.ndarray]) -> List[tf.Tensor]: + """For each variable in the map, assign the value from the dict""" + + trainable_vars = tf.trainable_variables() + assign_ops = [] + + for var in trainable_vars: + map_key = var.op.name + + if map_key not in VARIABLE_MAP: + raise ValueError("Map key {} not in variable map".format(map_key)) + + nem_var_list, fun = VARIABLE_MAP[map_key] + + for nem_var in nem_var_list: + if nem_var not in variables: + raise ValueError("Alleged nematus var {} not found in loaded " + "nematus vars.".format(nem_var)) + + if fun is None: + if len(nem_var_list) != 1: + raise ValueError( + "Var list for map key {} must have length 1. " + "Length {} found instead." 
+ .format(map_key, len(nem_var_list))) + + to_assign = variables[nem_var_list[0]] + else: + to_assign = fun([variables[v] for v in nem_var_list]) + + check_shape(var, to_assign) + assign_ops.append(tf.assign(var, to_assign)) + + return assign_ops + + +INI_HEADER = """\ +; This is an automatically generated configuration file +; for running imported nematus model +; For further training, set the configuration as appropriate + +[main] +name="nematus imported translation" +tf_manager= +output="{}" +runners=[] +postprocess=None +evaluation=[("target", evaluators.bleu.BLEU)] +runners_batch_size=1 + +; TODO Set these additional attributes for further training +; batch_size=80 +; epochs=10 +; train_dataset= +; val_dataset= +; trainer= +; logging_period=20 +; validation_period=60 +; random_seed=1234 + +; [train_data] +; class=dataset.load_dataset_from_files +; s_source="PATH/TO/DATA" ; TODO do not forget to fill this out! +; s_target="PATH/TO/DATA" ; TODO do not forget to fill this out! +; lazy=True + +; [val_data] +; class=dataset.load_dataset_from_files +; s_source="PATH/TO/DATA" ; TODO do not forget to fill this out! +; s_target="PATH/TO/DATA" ; TODO do not forget to fill this out! + +; [trainer] +; class=trainers.cross_entropy_trainer.CrossEntropyTrainer +; decoders=[] +; l2_weight=1.0e-8 +; clip_norm=1.0 + +[tf_manager] +class=tf_manager.TensorFlowManager +num_threads=4 +num_sessions=1 + +[runner] +class=runners.runner.GreedyRunner +decoder= +output_series="target" +""" + + +def write_config(experiment_dir: str, ini: str) -> None: + experiment_file = os.path.join(experiment_dir, "experiment.ini") + with open(experiment_file, "w", encoding="utf-8") as f_out: + f_out.write(INI_HEADER.format(experiment_dir)) + f_out.write(ini) + + +def prepare_output_dir(output_dir: str) -> bool: + if os.path.isdir(output_dir): + log("Directory {} already exists. Choose a nonexistent one.". 
+ format(output_dir)) + exit(1) + + os.mkdir(output_dir) + + +def main() -> None: + log("Script started.") + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("nematus_json", metavar="NEMATUS-JSON", + help="nematus json file") + parser.add_argument("nematus_variables", metavar="NEMATUS-FILE", + help="nematus variable file") + parser.add_argument("output_dir", metavar="OUTPUT-DIR", + help="output directory") + args = parser.parse_args() + + log("Loading nematus variables from {}.".format(args.nematus_variables)) + nematus_vars = load_nematus_file(args.nematus_variables) + + log("Loading nematus JSON config from {}.".format(args.nematus_json)) + nematus_json_cfg = load_nematus_json(args.nematus_json) + + log("Bulding model.") + ini = build_model(nematus_json_cfg) + + log("Defining assign Ops.") + assign_ops = assign_vars(nematus_vars) + + log("Preparing output directory {}".format(args.output_dir)) + prepare_output_dir(args.output_dir) + + log("Writing configuration file to {}/experiment.ini." + .format(args.output_dir)) + write_config(args.output_dir, ini) + + log("Creating TF session.") + s = tf.Session() + + log("Running session to assign to Neural Monkey variables.") + s.run(assign_ops) + + log("Initializing saver.") + saver = tf.train.Saver() + + variables_file = os.path.join(args.output_dir, "variables.data") + log("Saving variables to {}".format(variables_file)) + saver.save(s, variables_file) + + log("Finished.") + + +if __name__ == "__main__": + main() diff --git a/tests/small.ini b/tests/small.ini index 7ff4f28bb..512f711f8 100644 --- a/tests/small.ini +++ b/tests/small.ini @@ -59,6 +59,7 @@ embedding_size=11 dropout_keep_prob=0.5 data_id="source" vocabulary= +rnn_cell="NematusGRU" [attention] class=attention.Attention @@ -81,6 +82,8 @@ dropout_keep_prob=0.5 data_id="target" max_output_len=1 vocabulary= +attention_on_input=False +rnn_cell="NematusGRU" [trainer] ; This block just fills the arguments of the trainer __init__ method.