-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
29 changed files
with
6,341 additions
and
0 deletions.
There are no files selected for viewing
Binary file added
BIN
+163 Bytes
Basic/Albert/albert_tiny_tf/albert/__pycache__/__init__.cpython-36.pyc
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import os | ||
import tensorflow as tf | ||
|
||
tf.logging.set_verbosity(tf.logging.INFO) | ||
|
||
file_path = os.path.dirname(__file__) | ||
|
||
|
||
#模型目录 | ||
model_dir = os.path.join(file_path, 'albert_lcqmc_checkpoints/') | ||
|
||
#config文件 | ||
config_name = os.path.join(file_path, 'albert_config/albert_config.json') | ||
#ckpt文件名称 | ||
ckpt_name = os.path.join(model_dir, 'model.ckpt') | ||
#输出文件目录 | ||
output_dir = os.path.join(file_path, 'albert_lcqmc_checkpoints/') | ||
#vocab文件目录 | ||
vocab_file = os.path.join(file_path, 'albert_config/vocab.txt') | ||
#数据目录 | ||
data_dir = os.path.join(file_path, 'data/') | ||
|
||
num_train_epochs = 10 | ||
batch_size = 128 | ||
learning_rate = 0.00005 | ||
|
||
# gpu使用率 | ||
gpu_memory_fraction = 0.8 | ||
|
||
# 默认取倒数第二层的输出值作为句向量 | ||
layer_indexes = [-2] | ||
|
||
# 序列的最大程度,单文本建议把该值调小 | ||
max_seq_len = 128 | ||
|
||
# graph名字 | ||
graph_file = os.path.join(file_path, 'albert_lcqmc_checkpoints/graph') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import collections | ||
import copy | ||
import json | ||
import math | ||
import re | ||
import six | ||
import tensorflow as tf | ||
|
||
|
||
def get_shape_list(tensor, expected_rank=None, name=None): | ||
"""Returns a list of the shape of tensor, preferring static dimensions. | ||
Args: | ||
tensor: A tf.Tensor object to find the shape of. | ||
expected_rank: (optional) int. The expected rank of `tensor`. If this is | ||
specified and the `tensor` has a different rank, and exception will be | ||
thrown. | ||
name: Optional name of the tensor for the error message. | ||
Returns: | ||
A list of dimensions of the shape of tensor. All static dimensions will | ||
be returned as python integers, and dynamic dimensions will be returned | ||
as tf.Tensor scalars. | ||
""" | ||
if name is None: | ||
name = tensor.name | ||
|
||
if expected_rank is not None: | ||
assert_rank(tensor, expected_rank, name) | ||
|
||
shape = tensor.shape.as_list() | ||
|
||
non_static_indexes = [] | ||
for (index, dim) in enumerate(shape): | ||
if dim is None: | ||
non_static_indexes.append(index) | ||
|
||
if not non_static_indexes: | ||
return shape | ||
|
||
dyn_shape = tf.shape(tensor) | ||
for index in non_static_indexes: | ||
shape[index] = dyn_shape[index] | ||
return shape | ||
|
||
|
||
def reshape_to_matrix(input_tensor): | ||
"""Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix).""" | ||
ndims = input_tensor.shape.ndims | ||
if ndims < 2: | ||
raise ValueError("Input tensor must have at least rank 2. Shape = %s" % | ||
(input_tensor.shape)) | ||
if ndims == 2: | ||
return input_tensor | ||
|
||
width = input_tensor.shape[-1] | ||
output_tensor = tf.reshape(input_tensor, [-1, width]) | ||
return output_tensor | ||
|
||
|
||
def reshape_from_matrix(output_tensor, orig_shape_list): | ||
"""Reshapes a rank 2 tensor back to its original rank >= 2 tensor.""" | ||
if len(orig_shape_list) == 2: | ||
return output_tensor | ||
|
||
output_shape = get_shape_list(output_tensor) | ||
|
||
orig_dims = orig_shape_list[0:-1] | ||
width = output_shape[-1] | ||
|
||
return tf.reshape(output_tensor, orig_dims + [width]) | ||
|
||
|
||
def assert_rank(tensor, expected_rank, name=None): | ||
"""Raises an exception if the tensor rank is not of the expected rank. | ||
Args: | ||
tensor: A tf.Tensor to check the rank of. | ||
expected_rank: Python integer or list of integers, expected rank. | ||
name: Optional name of the tensor for the error message. | ||
Raises: | ||
ValueError: If the expected shape doesn't match the actual shape. | ||
""" | ||
if name is None: | ||
name = tensor.name | ||
|
||
expected_rank_dict = {} | ||
if isinstance(expected_rank, six.integer_types): | ||
expected_rank_dict[expected_rank] = True | ||
else: | ||
for x in expected_rank: | ||
expected_rank_dict[x] = True | ||
|
||
actual_rank = tensor.shape.ndims | ||
if actual_rank not in expected_rank_dict: | ||
scope_name = tf.get_variable_scope().name | ||
raise ValueError( | ||
"For the tensor `%s` in scope `%s`, the actual rank " | ||
"`%d` (shape = %s) is not equal to the expected rank `%s`" % | ||
(name, scope_name, actual_rank, str(tensor.shape), str(expected_rank))) | ||
|
||
|
||
def gather_indexes(sequence_tensor, positions): | ||
"""Gathers the vectors at the specific positions over a minibatch.""" | ||
sequence_shape = get_shape_list(sequence_tensor, expected_rank=3) | ||
batch_size = sequence_shape[0] | ||
seq_length = sequence_shape[1] | ||
width = sequence_shape[2] | ||
|
||
flat_offsets = tf.reshape( | ||
tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]) | ||
flat_positions = tf.reshape(positions + flat_offsets, [-1]) | ||
flat_sequence_tensor = tf.reshape(sequence_tensor, | ||
[batch_size * seq_length, width]) | ||
output_tensor = tf.gather(flat_sequence_tensor, flat_positions) | ||
return output_tensor | ||
|
||
|
||
# add sequence mask for: | ||
# 1. random shuffle lm modeling---xlnet with random shuffled input | ||
# 2. left2right and right2left language modeling | ||
# 3. conditional generation | ||
def generate_seq2seq_mask(attention_mask, mask_sequence, seq_type, **kargs): | ||
if seq_type == 'seq2seq': | ||
if mask_sequence is not None: | ||
seq_shape = get_shape_list(mask_sequence, expected_rank=2) | ||
seq_len = seq_shape[1] | ||
ones = tf.ones((1, seq_len, seq_len)) | ||
a_mask = tf.matrix_band_part(ones, -1, 0) | ||
s_ex12 = tf.expand_dims(tf.expand_dims(mask_sequence, 1), 2) | ||
s_ex13 = tf.expand_dims(tf.expand_dims(mask_sequence, 1), 3) | ||
a_mask = (1 - s_ex13) * (1 - s_ex12) + s_ex13 * a_mask | ||
# generate mask of batch x seq_len x seq_len | ||
a_mask = tf.reshape(a_mask, (-1, seq_len, seq_len)) | ||
out_mask = attention_mask * a_mask | ||
else: | ||
ones = tf.ones_like(attention_mask[:1]) | ||
mask = (tf.matrix_band_part(ones, -1, 0)) | ||
out_mask = attention_mask * mask | ||
else: | ||
out_mask = attention_mask | ||
|
||
return out_mask |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
#!/usr/bin/env bash | ||
|
||
BERT_BASE_DIR=./albert_config | ||
python3 create_pretraining_data.py --do_whole_word_mask=True --input_file=data/news_zh_1.txt \ | ||
--output_file=data/tf_news_2016_zh_raw_news2016zh_1.tfrecord --vocab_file=$BERT_BASE_DIR/vocab.txt --do_lower_case=True \ | ||
--max_seq_length=512 --max_predictions_per_seq=51 --masked_lm_prob=0.10 |
Oops, something went wrong.