From a41de5e80eb40e2464bf421754ba7e7372bbc232 Mon Sep 17 00:00:00 2001 From: "xiangpeng.wxp" Date: Fri, 5 Aug 2022 23:48:46 +0800 Subject: [PATCH] =?UTF-8?q?[to=20#42322933]Merge=20request=20from=20?= =?UTF-8?q?=E9=B9=8F=E7=A8=8B:nlp=5Ftranslation=5Ffinetune?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * csanmt finetune wxp --- .../models/nlp/csanmt_for_translation.py | 516 +++++++++++++++++- modelscope/pipelines/builder.py | 2 +- .../pipelines/nlp/translation_pipeline.py | 98 ++-- modelscope/trainers/nlp/__init__.py | 5 +- .../nlp/csanmt_translation_trainer.py | 324 +++++++++++ tests/pipelines/test_csanmt_translation.py | 10 +- tests/trainers/test_translation_trainer.py | 18 + 7 files changed, 916 insertions(+), 57 deletions(-) create mode 100644 modelscope/trainers/nlp/csanmt_translation_trainer.py create mode 100644 tests/trainers/test_translation_trainer.py diff --git a/modelscope/models/nlp/csanmt_for_translation.py b/modelscope/models/nlp/csanmt_for_translation.py index 41abd701..6906f41c 100644 --- a/modelscope/models/nlp/csanmt_for_translation.py +++ b/modelscope/models/nlp/csanmt_for_translation.py @@ -21,9 +21,11 @@ class CsanmtForTranslation(Model): params (dict): the model configuration. """ super().__init__(model_dir, *args, **kwargs) - self.params = kwargs + self.params = kwargs['params'] - def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: + def __call__(self, + input: Dict[str, Tensor], + label: Dict[str, Tensor] = None) -> Dict[str, Tensor]: """return the result by the model Args: @@ -32,12 +34,32 @@ class CsanmtForTranslation(Model): Returns: output_seqs: output sequence of target ids """ - with tf.compat.v1.variable_scope('NmtModel'): - output_seqs, output_scores = self.beam_search(input, self.params) - return { - 'output_seqs': output_seqs, - 'output_scores': output_scores, - } + if label is None: + with tf.compat.v1.variable_scope('NmtModel'): + output_seqs, output_scores = self.beam_search( + input, self.params) + return { + 'output_seqs': output_seqs, + 'output_scores': output_scores, + } + else: + train_op, loss = self.transformer_model_train_fn(input, label) + return { + 'train_op': train_op, + 'loss': loss, + } + + def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: + """ + Run the forward pass for a model. + + Args: + input (Dict[str, Tensor]): the dict of the model inputs for the forward method + + Returns: + Dict[str, Tensor]: output from the model forward pass + """ + ... 
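To make the two entry points concrete: with no `label`, `__call__` builds the beam-search inference graph and returns `output_seqs`/`output_scores`; with a target batch it builds the training graph and returns `train_op`/`loss`. A minimal usage sketch follows, assuming `model_dir` points at a downloaded checkpoint directory and `params` carries the hyperparameters that the pipeline and trainer below assemble from configuration.json (both are stand-ins here, not values defined in this patch):

import tensorflow as tf

from modelscope.models.nlp import CsanmtForTranslation

if tf.__version__ >= '2.0':
    tf = tf.compat.v1
    tf.disable_eager_execution()

# Stand-ins: in practice the pipeline/trainer fill these from configuration.json.
model_dir = '/path/to/nlp_csanmt_translation_zh2en'
params = {}  # must contain hidden_size, vocab sizes, beam_size, etc.

# Hypothetical placeholders; the real ones are created by TranslationPipeline
# and CsanmtTranslationTrainer later in this patch.
source_wids = tf.placeholder(tf.int64, shape=[None, None], name='source_wids')
target_wids = tf.placeholder(tf.int64, shape=[None, None], name='target_wids')

model = CsanmtForTranslation(model_dir, params=params)

# Inference mode: no label, so __call__ runs beam search.
infer_ops = model(input=source_wids)
# -> {'output_seqs': ..., 'output_scores': ...}

# Training mode (what the trainer does): a label builds the training graph.
# train_ops = model(input=source_wids, label=target_wids)
# -> {'train_op': ..., 'loss': ...}

# The pipeline/trainer then restore the checkpoint and run these ops in a
# tf.Session with a feed_dict of token ids.

Routing both modes through `__call__` keeps graph construction in one place for the TF1-style pipeline and trainer that follow in this patch.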
def encoding_graph(self, features, params): src_vocab_size = params['src_vocab_size'] @@ -137,6 +159,278 @@ class CsanmtForTranslation(Model): params) return encoder_output + def build_contrastive_training_graph(self, features, labels, params): + # representations + source_name = 'source' + target_name = 'target' + if params['shared_source_target_embedding']: + source_name = None + target_name = None + feature_output = self.semantic_encoding_graph( + features, params, name=source_name) + label_output = self.semantic_encoding_graph( + labels, params, name=target_name) + + return feature_output, label_output + + def MGMC_sampling(self, x_embedding, y_embedding, params, epsilon=1e-12): + K = params['num_of_samples'] + eta = params['eta'] + assert K % 2 == 0 + + def get_samples(x_vector, y_vector): + bias_vector = y_vector - x_vector + w_r = tf.math.divide( + tf.abs(bias_vector) - tf.reduce_min( + input_tensor=tf.abs(bias_vector), axis=2, keepdims=True) + + epsilon, + tf.reduce_max( + input_tensor=tf.abs(bias_vector), axis=2, keepdims=True) + - tf.reduce_min( + input_tensor=tf.abs(bias_vector), axis=2, keepdims=True) + + 2 * epsilon) + + R = [] + for i in range(K // 2): + omega = eta * tf.random.normal(tf.shape(input=bias_vector), 0.0, w_r) + \ + (1.0 - eta) * tf.random.normal(tf.shape(input=bias_vector), 0.0, 1.0) + sample = x_vector + omega * bias_vector + R.append(sample) + return R + + ALL_SAMPLES = [] + ALL_SAMPLES = get_samples(x_embedding, y_embedding) + ALL_SAMPLES.extend(get_samples(y_embedding, x_embedding)) + + assert len(ALL_SAMPLES) == K + + return tf.concat(ALL_SAMPLES, axis=0) + + def decoding_graph(self, + encoder_output, + encoder_self_attention_bias, + labels, + params={}, + embedding_augmentation=None): + trg_vocab_size = params['trg_vocab_size'] + hidden_size = params['hidden_size'] + + initializer = tf.compat.v1.random_normal_initializer( + 0.0, hidden_size**-0.5, dtype=tf.float32) + + if params['shared_source_target_embedding']: + with tf.compat.v1.variable_scope( + 'Shared_Embedding', reuse=tf.compat.v1.AUTO_REUSE): + trg_embedding = tf.compat.v1.get_variable( + 'Weights', [trg_vocab_size, hidden_size], + initializer=initializer) + else: + with tf.compat.v1.variable_scope('Target_Embedding'): + trg_embedding = tf.compat.v1.get_variable( + 'Weights', [trg_vocab_size, hidden_size], + initializer=initializer) + + eos_padding = tf.zeros([tf.shape(input=labels)[0], 1], tf.int64) + trg_seq = tf.concat([labels, eos_padding], 1) + trg_mask = tf.cast(tf.not_equal(trg_seq, 0), dtype=tf.float32) + shift_trg_mask = trg_mask[:, :-1] + shift_trg_mask = tf.pad( + tensor=shift_trg_mask, + paddings=[[0, 0], [1, 0]], + constant_values=1) + + decoder_input = tf.gather(trg_embedding, tf.cast(trg_seq, tf.int32)) + + decoder_input *= hidden_size**0.5 + decoder_self_attention_bias = attention_bias( + tf.shape(input=decoder_input)[1], 'causal') + decoder_input = tf.pad( + tensor=decoder_input, paddings=[[0, 0], [1, 0], [0, 0]])[:, :-1, :] + if params['position_info_type'] == 'absolute': + decoder_input = add_timing_signal(decoder_input) + + decoder_input = tf.nn.dropout( + decoder_input, rate=1 - (1.0 - params['residual_dropout'])) + + # training + decoder_output, attention_weights = transformer_decoder( + decoder_input, + encoder_output, + decoder_self_attention_bias, + encoder_self_attention_bias, + states_key=None, + states_val=None, + embedding_augmentation=embedding_augmentation, + params=params) + + logits = self.prediction(decoder_output, params) + + on_value = params['confidence'] + 
off_value = (1.0 - params['confidence']) / tf.cast( + trg_vocab_size - 1, dtype=tf.float32) + soft_targets = tf.one_hot( + tf.cast(trg_seq, tf.int32), + depth=trg_vocab_size, + on_value=on_value, + off_value=off_value) + mask = tf.cast(shift_trg_mask, logits.dtype) + xentropy = tf.nn.softmax_cross_entropy_with_logits( + logits=logits, labels=tf.stop_gradient(soft_targets)) * mask + loss = tf.reduce_sum(input_tensor=xentropy) / tf.reduce_sum( + input_tensor=mask) + + return loss + + def build_training_graph(self, + features, + labels, + params, + feature_embedding=None, + label_embedding=None): + # encode + encoder_output, encoder_self_attention_bias = self.encoding_graph( + features, params) + embedding_augmentation = None + if feature_embedding is not None and label_embedding is not None: + embedding_augmentation = self.MGMC_sampling( + feature_embedding, label_embedding, params) + + encoder_output = tf.tile(encoder_output, + [params['num_of_samples'], 1, 1]) + encoder_self_attention_bias = tf.tile( + encoder_self_attention_bias, + [params['num_of_samples'], 1, 1, 1]) + labels = tf.tile(labels, [params['num_of_samples'], 1]) + + # decode + loss = self.decoding_graph( + encoder_output, + encoder_self_attention_bias, + labels, + params, + embedding_augmentation=embedding_augmentation) + + return loss + + def transformer_model_train_fn(self, features, labels): + initializer = get_initializer(self.params) + with tf.compat.v1.variable_scope('NmtModel', initializer=initializer): + num_gpus = self.params['num_gpus'] + gradient_clip_norm = self.params['gradient_clip_norm'] + global_step = tf.compat.v1.train.get_global_step() + print(global_step) + + # learning rate + learning_rate = get_learning_rate_decay( + self.params['learning_rate'], global_step, self.params) + learning_rate = tf.convert_to_tensor( + value=learning_rate, dtype=tf.float32) + + # optimizer + if self.params['optimizer'] == 'sgd': + optimizer = tf.compat.v1.train.GradientDescentOptimizer( + learning_rate) + elif self.params['optimizer'] == 'adam': + optimizer = tf.compat.v1.train.AdamOptimizer( + learning_rate=learning_rate, + beta1=self.params['adam_beta1'], + beta2=self.params['adam_beta2'], + epsilon=self.params['adam_epsilon']) + else: + tf.compat.v1.logging.info('optimizer not supported') + sys.exit() + opt = MultiStepOptimizer(optimizer, self.params['update_cycle']) + + def fill_gpus(inputs, num_gpus): + outputs = inputs + for i in range(num_gpus): + outputs = tf.concat([outputs, inputs], axis=0) + outputs = outputs[:num_gpus, ] + return outputs + + features = tf.cond( + pred=tf.shape(input=features)[0] < num_gpus, + true_fn=lambda: fill_gpus(features, num_gpus), + false_fn=lambda: features) + labels = tf.cond( + pred=tf.shape(input=labels)[0] < num_gpus, + true_fn=lambda: fill_gpus(labels, num_gpus), + false_fn=lambda: labels) + + if num_gpus > 0: + feature_shards = shard_features(features, num_gpus) + label_shards = shard_features(labels, num_gpus) + else: + feature_shards = [features] + label_shards = [labels] + + if num_gpus > 0: + devices = ['gpu:%d' % d for d in range(num_gpus)] + else: + devices = ['cpu:0'] + multi_grads = [] + sharded_losses = [] + + for i, device in enumerate(devices): + with tf.device(device), tf.compat.v1.variable_scope( + tf.compat.v1.get_variable_scope(), + reuse=True if i > 0 else None): + with tf.name_scope('%s_%d' % ('GPU', i)): + feature_output, label_output = self.build_contrastive_training_graph( + feature_shards[i], label_shards[i], self.params) + mle_loss = self.build_training_graph( + 
feature_shards[i], label_shards[i], self.params, + feature_output, label_output) + sharded_losses.append(mle_loss) + tf.compat.v1.summary.scalar('mle_loss_{}'.format(i), + mle_loss) + + # Optimization + trainable_vars_list = [ + v for v in tf.compat.v1.trainable_variables() + if 'Shared_Semantic_Embedding' not in v.name + and 'mini_xlm_encoder' not in v.name + ] + grads_and_vars = opt.compute_gradients( + mle_loss, + var_list=trainable_vars_list, + colocate_gradients_with_ops=True) + multi_grads.append(grads_and_vars) + + total_loss = tf.add_n(sharded_losses) / len(sharded_losses) + + # Average gradients + grads_and_vars = average_gradients(multi_grads) + + if gradient_clip_norm > 0.0: + grads, var_list = list(zip(*grads_and_vars)) + grads, _ = tf.clip_by_global_norm(grads, gradient_clip_norm) + grads_and_vars = zip(grads, var_list) + + train_op = opt.apply_gradients( + grads_and_vars, + global_step=tf.compat.v1.train.get_global_step()) + + return train_op, total_loss + + def prediction(self, decoder_output, params): + hidden_size = params['hidden_size'] + trg_vocab_size = params['trg_vocab_size'] + + if params['shared_embedding_and_softmax_weights']: + embedding_scope = 'Shared_Embedding' if params[ + 'shared_source_target_embedding'] else 'Target_Embedding' + with tf.compat.v1.variable_scope(embedding_scope, reuse=True): + weights = tf.compat.v1.get_variable('Weights') + else: + weights = tf.compat.v1.get_variable('Softmax', + [tgt_vocab_size, hidden_size]) + shape = tf.shape(input=decoder_output)[:-1] + decoder_output = tf.reshape(decoder_output, [-1, hidden_size]) + logits = tf.matmul(decoder_output, weights, transpose_b=True) + logits = tf.reshape(logits, tf.concat([shape, [trg_vocab_size]], 0)) + return logits + def inference_func(self, encoder_output, feature_output, @@ -193,7 +487,7 @@ class CsanmtForTranslation(Model): weights = tf.compat.v1.get_variable('Weights') else: weights = tf.compat.v1.get_variable('Softmax', - [tgt_vocab_size, hidden_size]) + [trg_vocab_size, hidden_size]) logits = tf.matmul(decoder_output_last, weights, transpose_b=True) log_prob = tf.nn.log_softmax(logits) return log_prob, attention_weights_last, states_key, states_val @@ -212,7 +506,11 @@ class CsanmtForTranslation(Model): encoder_output, encoder_self_attention_bias = self.encoding_graph( features, params) - feature_output = self.semantic_encoding_graph(features, params) + source_name = 'source' + if params['shared_source_target_embedding']: + source_name = None + feature_output = self.semantic_encoding_graph( + features, params, name=source_name) init_seqs = tf.fill([batch_size, beam_size, 1], 0) init_log_probs = \ @@ -585,7 +883,6 @@ def _residual_fn(x, y, keep_prob=None): def embedding_augmentation_layer(x, embedding_augmentation, params, name=None): hidden_size = params['hidden_size'] keep_prob = 1.0 - params['relu_dropout'] - layer_postproc = params['layer_postproc'] with tf.compat.v1.variable_scope( name, default_name='embedding_augmentation_layer', @@ -600,8 +897,7 @@ def embedding_augmentation_layer(x, embedding_augmentation, params, name=None): with tf.compat.v1.variable_scope('output_layer'): output = linear(hidden, hidden_size, True, True) - x = _layer_process(x + output, layer_postproc) - return x + return x + output def transformer_ffn_layer(x, params, name=None): @@ -740,8 +1036,10 @@ def transformer_decoder(decoder_input, if params['position_info_type'] == 'relative' else None # continuous semantic augmentation if embedding_augmentation is not None: - x = embedding_augmentation_layer(x, 
embedding_augmentation, - params) + x = embedding_augmentation_layer( + x, _layer_process(embedding_augmentation, + layer_preproc), params) + x = _layer_process(x, layer_postproc) o, w = multihead_attention( _layer_process(x, layer_preproc), None, @@ -1004,3 +1302,191 @@ def multihead_attention(queries, w = tf.reduce_mean(w, 1) x = linear(x, output_depth, True, True, scope='output_transform') return x, w + + +def get_initializer(params): + if params['initializer'] == 'uniform': + max_val = params['initializer_scale'] + return tf.compat.v1.random_uniform_initializer(-max_val, max_val) + elif params['initializer'] == 'normal': + return tf.compat.v1.random_normal_initializer( + 0.0, params['initializer_scale']) + elif params['initializer'] == 'normal_unit_scaling': + return tf.compat.v1.variance_scaling_initializer( + params['initializer_scale'], mode='fan_avg', distribution='normal') + elif params['initializer'] == 'uniform_unit_scaling': + return tf.compat.v1.variance_scaling_initializer( + params['initializer_scale'], + mode='fan_avg', + distribution='uniform') + else: + raise ValueError('Unrecognized initializer: %s' + % params['initializer']) + + +def get_learning_rate_decay(learning_rate, global_step, params): + if params['learning_rate_decay'] in ['linear_warmup_rsqrt_decay', 'noam']: + step = tf.cast(global_step, dtype=tf.float32) + warmup_steps = tf.cast(params['warmup_steps'], dtype=tf.float32) + multiplier = params['hidden_size']**-0.5 + decay = multiplier * tf.minimum((step + 1) * (warmup_steps**-1.5), + (step + 1)**-0.5) + return learning_rate * decay + elif params['learning_rate_decay'] == 'piecewise_constant': + return tf.compat.v1.train.piecewise_constant( + tf.cast(global_step, dtype=tf.int32), + params['learning_rate_boundaries'], params['learning_rate_values']) + elif params['learning_rate_decay'] == 'none': + return learning_rate + else: + raise ValueError('Unknown learning_rate_decay') + + +def average_gradients(tower_grads): + average_grads = [] + for grad_and_vars in zip(*tower_grads): + grads = [] + for g, _ in grad_and_vars: + expanded_g = tf.expand_dims(g, 0) + grads.append(expanded_g) + grad = tf.concat(axis=0, values=grads) + grad = tf.reduce_mean(grad, 0) + v = grad_and_vars[0][1] + grad_and_var = (grad, v) + average_grads.append(grad_and_var) + return average_grads + + +_ENGINE = None + + +def all_reduce(tensor): + if _ENGINE is None: + return tensor + + return _ENGINE.allreduce(tensor, compression=_ENGINE.Compression.fp16) + + +class MultiStepOptimizer(tf.compat.v1.train.Optimizer): + + def __init__(self, + optimizer, + step=1, + use_locking=False, + name='MultiStepOptimizer'): + super(MultiStepOptimizer, self).__init__(use_locking, name) + self._optimizer = optimizer + self._step = step + self._step_t = tf.convert_to_tensor(step, name='step') + + def _all_reduce(self, tensor): + with tf.name_scope(self._name + '_Allreduce'): + if tensor is None: + return tensor + + if isinstance(tensor, tf.IndexedSlices): + tensor = tf.convert_to_tensor(tensor) + + return all_reduce(tensor) + + def compute_gradients(self, + loss, + var_list=None, + gate_gradients=tf.compat.v1.train.Optimizer.GATE_OP, + aggregation_method=None, + colocate_gradients_with_ops=False, + grad_loss=None): + grads_and_vars = self._optimizer.compute_gradients( + loss, var_list, gate_gradients, aggregation_method, + colocate_gradients_with_ops, grad_loss) + + grads, var_list = list(zip(*grads_and_vars)) + + # Do not create extra variables when step is 1 + if self._step == 1: + grads = 
[self._all_reduce(t) for t in grads] + return list(zip(grads, var_list)) + + first_var = min(var_list, key=lambda x: x.name) + iter_var = self._create_non_slot_variable( + initial_value=0 if self._step == 1 else 1, + name='iter', + colocate_with=first_var) + + new_grads = [] + + for grad, var in zip(grads, var_list): + grad_acc = self._zeros_slot(var, 'grad_acc', self._name) + + if isinstance(grad, tf.IndexedSlices): + grad_acc = tf.scatter_add( + grad_acc, + grad.indices, + grad.values, + use_locking=self._use_locking) + else: + grad_acc = tf.assign_add( + grad_acc, grad, use_locking=self._use_locking) + + def _acc_grad(): + return grad_acc + + def _avg_grad(): + return self._all_reduce(grad_acc / self._step) + + grad = tf.cond(tf.equal(iter_var, 0), _avg_grad, _acc_grad) + new_grads.append(grad) + + return list(zip(new_grads, var_list)) + + def apply_gradients(self, grads_and_vars, global_step=None, name=None): + if self._step == 1: + return self._optimizer.apply_gradients( + grads_and_vars, global_step, name=name) + + grads, var_list = list(zip(*grads_and_vars)) + + def _pass_gradients(): + return tf.group(*grads) + + def _apply_gradients(): + op = self._optimizer.apply_gradients( + zip(grads, var_list), global_step, name) + with tf.control_dependencies([op]): + zero_ops = [] + for var in var_list: + grad_acc = self.get_slot(var, 'grad_acc') + zero_ops.append( + grad_acc.assign( + tf.zeros_like(grad_acc), + use_locking=self._use_locking)) + zero_op = tf.group(*zero_ops) + return tf.group(*[op, zero_op]) + + iter_var = self._get_non_slot_variable('iter', tf.get_default_graph()) + update_op = tf.cond( + tf.equal(iter_var, 0), _apply_gradients, _pass_gradients) + + with tf.control_dependencies([update_op]): + iter_op = iter_var.assign( + tf.mod(iter_var + 1, self._step_t), + use_locking=self._use_locking) + + return tf.group(*[update_op, iter_op]) + + +def shard_features(x, num_datashards): + x = tf.convert_to_tensor(x) + batch_size = tf.shape(x)[0] + size_splits = [] + + with tf.device('/cpu:0'): + for i in range(num_datashards): + size_splits.append( + tf.cond( + tf.greater( + tf.compat.v1.mod(batch_size, num_datashards), + i), lambda: batch_size // num_datashards + 1, + lambda: batch_size // num_datashards)) + + return tf.split(x, size_splits, axis=0) diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index ea18d7b7..743ba1cb 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -27,7 +27,7 @@ DEFAULT_MODEL_FOR_PIPELINE = { (Pipelines.sentence_similarity, 'damo/nlp_structbert_sentence-similarity_chinese-base'), Tasks.translation: (Pipelines.csanmt_translation, - 'damo/nlp_csanmt_translation'), + 'damo/nlp_csanmt_translation_zh2en'), Tasks.nli: (Pipelines.nli, 'damo/nlp_structbert_nli_chinese-base'), Tasks.sentiment_classification: (Pipelines.sentiment_classification, diff --git a/modelscope/pipelines/nlp/translation_pipeline.py b/modelscope/pipelines/nlp/translation_pipeline.py index dba3fe9f..67ff3927 100644 --- a/modelscope/pipelines/nlp/translation_pipeline.py +++ b/modelscope/pipelines/nlp/translation_pipeline.py @@ -1,4 +1,5 @@ import os.path as osp +from threading import Lock from typing import Any, Dict import numpy as np @@ -8,59 +9,38 @@ from modelscope.metainfo import Pipelines from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Pipeline from modelscope.pipelines.builder import PIPELINES -from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.config import Config 
+from modelscope.utils.constant import Frameworks, ModelFile, Tasks from modelscope.utils.logger import get_logger if tf.__version__ >= '2.0': tf = tf.compat.v1 -tf.disable_eager_execution() + tf.disable_eager_execution() logger = get_logger() __all__ = ['TranslationPipeline'] -# constant -PARAMS = { - 'hidden_size': 512, - 'filter_size': 2048, - 'num_heads': 8, - 'num_encoder_layers': 6, - 'num_decoder_layers': 6, - 'attention_dropout': 0.0, - 'residual_dropout': 0.0, - 'relu_dropout': 0.0, - 'layer_preproc': 'none', - 'layer_postproc': 'layer_norm', - 'shared_embedding_and_softmax_weights': True, - 'shared_source_target_embedding': True, - 'initializer_scale': 0.1, - 'train_max_len': 100, - 'confidence': 0.9, - 'position_info_type': 'absolute', - 'max_relative_dis': 16, - 'beam_size': 4, - 'lp_rate': 0.6, - 'num_semantic_encoder_layers': 4, - 'max_decoded_trg_len': 100, - 'src_vocab_size': 37006, - 'trg_vocab_size': 37006, - 'vocab_src': 'src_vocab.txt', - 'vocab_trg': 'trg_vocab.txt' -} - @PIPELINES.register_module( Tasks.translation, module_name=Pipelines.csanmt_translation) class TranslationPipeline(Pipeline): def __init__(self, model: str, **kwargs): - super().__init__(model=model) - model = self.model.model_dir tf.reset_default_graph() + self.framework = Frameworks.tf + self.device_name = 'cpu' + + super().__init__(model=model) + model_path = osp.join( osp.join(model, ModelFile.TF_CHECKPOINT_FOLDER), 'ckpt-0') - self.params = PARAMS + self.cfg = Config.from_file(osp.join(model, ModelFile.CONFIGURATION)) + + self.params = {} + self._override_params_from_file() + self._src_vocab_path = osp.join(model, self.params['vocab_src']) self._src_vocab = dict([ (w.strip(), i) for i, w in enumerate(open(self._src_vocab_path)) @@ -70,15 +50,16 @@ class TranslationPipeline(Pipeline): (i, w.strip()) for i, w in enumerate(open(self._trg_vocab_path)) ]) - config = tf.ConfigProto(allow_soft_placement=True) - config.gpu_options.allow_growth = True - self._session = tf.Session(config=config) + tf_config = tf.ConfigProto(allow_soft_placement=True) + tf_config.gpu_options.allow_growth = True + self._session = tf.Session(config=tf_config) self.input_wids = tf.placeholder( dtype=tf.int64, shape=[None, None], name='input_wids') self.output = {} # model + self.model = CsanmtForTranslation(model_path, params=self.params) output = self.model(self.input_wids) self.output.update(output) @@ -88,6 +69,49 @@ class TranslationPipeline(Pipeline): model_loader = tf.train.Saver(tf.global_variables()) model_loader.restore(sess, model_path) + def _override_params_from_file(self): + + # model + self.params['hidden_size'] = self.cfg['model']['hidden_size'] + self.params['filter_size'] = self.cfg['model']['filter_size'] + self.params['num_heads'] = self.cfg['model']['num_heads'] + self.params['num_encoder_layers'] = self.cfg['model'][ + 'num_encoder_layers'] + self.params['num_decoder_layers'] = self.cfg['model'][ + 'num_decoder_layers'] + self.params['layer_preproc'] = self.cfg['model']['layer_preproc'] + self.params['layer_postproc'] = self.cfg['model']['layer_postproc'] + self.params['shared_embedding_and_softmax_weights'] = self.cfg[ + 'model']['shared_embedding_and_softmax_weights'] + self.params['shared_source_target_embedding'] = self.cfg['model'][ + 'shared_source_target_embedding'] + self.params['initializer_scale'] = self.cfg['model'][ + 'initializer_scale'] + self.params['position_info_type'] = self.cfg['model'][ + 'position_info_type'] + self.params['max_relative_dis'] = self.cfg['model']['max_relative_dis'] + 
self.params['num_semantic_encoder_layers'] = self.cfg['model'][ + 'num_semantic_encoder_layers'] + self.params['src_vocab_size'] = self.cfg['model']['src_vocab_size'] + self.params['trg_vocab_size'] = self.cfg['model']['trg_vocab_size'] + self.params['attention_dropout'] = 0.0 + self.params['residual_dropout'] = 0.0 + self.params['relu_dropout'] = 0.0 + + # dataset + self.params['vocab_src'] = self.cfg['dataset']['src_vocab']['file'] + self.params['vocab_trg'] = self.cfg['dataset']['trg_vocab']['file'] + + # train + self.params['train_max_len'] = self.cfg['train']['train_max_len'] + self.params['confidence'] = self.cfg['train']['confidence'] + + # evaluation + self.params['beam_size'] = self.cfg['evaluation']['beam_size'] + self.params['lp_rate'] = self.cfg['evaluation']['lp_rate'] + self.params['max_decoded_trg_len'] = self.cfg['evaluation'][ + 'max_decoded_trg_len'] + def preprocess(self, input: str) -> Dict[str, Any]: input_ids = np.array([[ self._src_vocab[w] diff --git a/modelscope/trainers/nlp/__init__.py b/modelscope/trainers/nlp/__init__.py index 888f9941..7ab8fd70 100644 --- a/modelscope/trainers/nlp/__init__.py +++ b/modelscope/trainers/nlp/__init__.py @@ -5,10 +5,11 @@ from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: from .sequence_classification_trainer import SequenceClassificationTrainer - + from .csanmt_translation_trainer import CsanmtTranslationTrainer else: _import_structure = { - 'sequence_classification_trainer': ['SequenceClassificationTrainer'] + 'sequence_classification_trainer': ['SequenceClassificationTrainer'], + 'csanmt_translation_trainer': ['CsanmtTranslationTrainer'], } import sys diff --git a/modelscope/trainers/nlp/csanmt_translation_trainer.py b/modelscope/trainers/nlp/csanmt_translation_trainer.py new file mode 100644 index 00000000..219c5ff1 --- /dev/null +++ b/modelscope/trainers/nlp/csanmt_translation_trainer.py @@ -0,0 +1,324 @@ +import os.path as osp +from typing import Dict, Optional + +import tensorflow as tf + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models.nlp import CsanmtForTranslation +from modelscope.trainers.base import BaseTrainer +from modelscope.trainers.builder import TRAINERS +from modelscope.utils.constant import ModelFile +from modelscope.utils.logger import get_logger + +if tf.__version__ >= '2.0': + tf = tf.compat.v1 + tf.disable_eager_execution() + +logger = get_logger() + + +@TRAINERS.register_module(module_name=r'csanmt-translation') +class CsanmtTranslationTrainer(BaseTrainer): + + def __init__(self, model: str, cfg_file: str = None, *args, **kwargs): + if not osp.exists(model): + model = snapshot_download(model) + tf.reset_default_graph() + + self.model_dir = model + self.model_path = osp.join(model, ModelFile.TF_CHECKPOINT_FOLDER) + if cfg_file is None: + cfg_file = osp.join(model, ModelFile.CONFIGURATION) + + super().__init__(cfg_file) + + self.params = {} + self._override_params_from_file() + + tf_config = tf.ConfigProto(allow_soft_placement=True) + tf_config.gpu_options.allow_growth = True + self._session = tf.Session(config=tf_config) + + self.source_wids = tf.placeholder( + dtype=tf.int64, shape=[None, None], name='source_wids') + self.target_wids = tf.placeholder( + dtype=tf.int64, shape=[None, None], name='target_wids') + self.output = {} + + self.global_step = tf.train.create_global_step() + + self.model = CsanmtForTranslation(self.model_path, params=self.params) + output = self.model(input=self.source_wids, label=self.target_wids) + 
self.output.update(output) + + self.model_saver = tf.train.Saver( + tf.global_variables(), + max_to_keep=self.params['keep_checkpoint_max']) + with self._session.as_default() as sess: + logger.info(f'loading model from {self.model_path}') + + pretrained_variables_map = get_pretrained_variables_map( + self.model_path) + + tf.train.init_from_checkpoint(self.model_path, + pretrained_variables_map) + sess.run(tf.global_variables_initializer()) + + def _override_params_from_file(self): + + self.params['hidden_size'] = self.cfg['model']['hidden_size'] + self.params['filter_size'] = self.cfg['model']['filter_size'] + self.params['num_heads'] = self.cfg['model']['num_heads'] + self.params['num_encoder_layers'] = self.cfg['model'][ + 'num_encoder_layers'] + self.params['num_decoder_layers'] = self.cfg['model'][ + 'num_decoder_layers'] + self.params['layer_preproc'] = self.cfg['model']['layer_preproc'] + self.params['layer_postproc'] = self.cfg['model']['layer_postproc'] + self.params['shared_embedding_and_softmax_weights'] = self.cfg[ + 'model']['shared_embedding_and_softmax_weights'] + self.params['shared_source_target_embedding'] = self.cfg['model'][ + 'shared_source_target_embedding'] + self.params['initializer_scale'] = self.cfg['model'][ + 'initializer_scale'] + self.params['position_info_type'] = self.cfg['model'][ + 'position_info_type'] + self.params['max_relative_dis'] = self.cfg['model']['max_relative_dis'] + self.params['num_semantic_encoder_layers'] = self.cfg['model'][ + 'num_semantic_encoder_layers'] + self.params['src_vocab_size'] = self.cfg['model']['src_vocab_size'] + self.params['trg_vocab_size'] = self.cfg['model']['trg_vocab_size'] + self.params['attention_dropout'] = 0.0 + self.params['residual_dropout'] = 0.0 + self.params['relu_dropout'] = 0.0 + + self.params['train_src'] = self.cfg['dataset']['train_src'] + self.params['train_trg'] = self.cfg['dataset']['train_trg'] + self.params['vocab_src'] = self.cfg['dataset']['src_vocab']['file'] + self.params['vocab_trg'] = self.cfg['dataset']['trg_vocab']['file'] + + self.params['num_gpus'] = self.cfg['train']['num_gpus'] + self.params['warmup_steps'] = self.cfg['train']['warmup_steps'] + self.params['update_cycle'] = self.cfg['train']['update_cycle'] + self.params['keep_checkpoint_max'] = self.cfg['train'][ + 'keep_checkpoint_max'] + self.params['confidence'] = self.cfg['train']['confidence'] + self.params['optimizer'] = self.cfg['train']['optimizer'] + self.params['adam_beta1'] = self.cfg['train']['adam_beta1'] + self.params['adam_beta2'] = self.cfg['train']['adam_beta2'] + self.params['adam_epsilon'] = self.cfg['train']['adam_epsilon'] + self.params['gradient_clip_norm'] = self.cfg['train'][ + 'gradient_clip_norm'] + self.params['learning_rate_decay'] = self.cfg['train'][ + 'learning_rate_decay'] + self.params['initializer'] = self.cfg['train']['initializer'] + self.params['initializer_scale'] = self.cfg['train'][ + 'initializer_scale'] + self.params['learning_rate'] = self.cfg['train']['learning_rate'] + self.params['train_batch_size_words'] = self.cfg['train'][ + 'train_batch_size_words'] + self.params['scale_l1'] = self.cfg['train']['scale_l1'] + self.params['scale_l2'] = self.cfg['train']['scale_l2'] + self.params['train_max_len'] = self.cfg['train']['train_max_len'] + self.params['max_training_steps'] = self.cfg['train'][ + 'max_training_steps'] + self.params['save_checkpoints_steps'] = self.cfg['train'][ + 'save_checkpoints_steps'] + self.params['num_of_samples'] = self.cfg['train']['num_of_samples'] + self.params['eta'] = 
self.cfg['train']['eta'] + + self.params['beam_size'] = self.cfg['evaluation']['beam_size'] + self.params['lp_rate'] = self.cfg['evaluation']['lp_rate'] + self.params['max_decoded_trg_len'] = self.cfg['evaluation'][ + 'max_decoded_trg_len'] + + self.params['seed'] = self.cfg['model']['seed'] + + def train(self, *args, **kwargs): + logger.info('Begin csanmt training') + + train_src = osp.join(self.model_dir, self.params['train_src']) + train_trg = osp.join(self.model_dir, self.params['train_trg']) + vocab_src = osp.join(self.model_dir, self.params['vocab_src']) + vocab_trg = osp.join(self.model_dir, self.params['vocab_trg']) + + iteration = 0 + + with self._session.as_default() as tf_session: + while True: + iteration += 1 + if iteration >= self.params['max_training_steps']: + break + + train_input_fn = input_fn( + train_src, + train_trg, + vocab_src, + vocab_trg, + batch_size_words=self.params['train_batch_size_words'], + max_len=self.params['train_max_len'], + num_gpus=self.params['num_gpus'] + if self.params['num_gpus'] > 0 else 1, + is_train=True, + session=tf_session, + iteration=iteration) + + features, labels = train_input_fn + + features_batch, labels_batch = tf_session.run( + [features, labels]) + + feed_dict = { + self.source_wids: features_batch, + self.target_wids: labels_batch + } + sess_outputs = self._session.run( + self.output, feed_dict=feed_dict) + loss_step = sess_outputs['loss'] + logger.info('Iteration: {}, step loss: {:.6f}'.format( + iteration, loss_step)) + + if iteration % self.params['save_checkpoints_steps'] == 0: + tf.logging.info('%s: Saving model on step: %d.' % + (__name__, iteration)) + ck_path = self.model_dir + 'model.ckpt' + self.model_saver.save( + tf_session, + ck_path, + global_step=tf.train.get_global_step()) + + tf.logging.info('%s: NMT training completed at time: %s.') + + def evaluate(self, + checkpoint_path: Optional[str] = None, + *args, + **kwargs) -> Dict[str, float]: + """evaluate a dataset + + evaluate a dataset via a specific model from the `checkpoint_path` path, if the `checkpoint_path` + does not exist, read from the config file. + + Args: + checkpoint_path (Optional[str], optional): the model path. Defaults to None. 
+ + Returns: + Dict[str, float]: the results about the evaluation + Example: + {"accuracy": 0.5091743119266054, "f1": 0.673780487804878} + """ + pass + + +def input_fn(src_file, + trg_file, + src_vocab_file, + trg_vocab_file, + num_buckets=20, + max_len=100, + batch_size=200, + batch_size_words=4096, + num_gpus=1, + is_train=True, + session=None, + iteration=None): + src_vocab = tf.lookup.StaticVocabularyTable( + tf.lookup.TextFileInitializer( + src_vocab_file, + key_dtype=tf.string, + key_index=tf.lookup.TextFileIndex.WHOLE_LINE, + value_dtype=tf.int64, + value_index=tf.lookup.TextFileIndex.LINE_NUMBER), + num_oov_buckets=1) # NOTE unk-> vocab_size + trg_vocab = tf.lookup.StaticVocabularyTable( + tf.lookup.TextFileInitializer( + trg_vocab_file, + key_dtype=tf.string, + key_index=tf.lookup.TextFileIndex.WHOLE_LINE, + value_dtype=tf.int64, + value_index=tf.lookup.TextFileIndex.LINE_NUMBER), + num_oov_buckets=1) # NOTE unk-> vocab_size + src_dataset = tf.data.TextLineDataset(src_file) + trg_dataset = tf.data.TextLineDataset(trg_file) + src_trg_dataset = tf.data.Dataset.zip((src_dataset, trg_dataset)) + src_trg_dataset = src_trg_dataset.map( + lambda src, trg: + (tf.string_split([src]).values, tf.string_split([trg]).values), + num_parallel_calls=10).prefetch(1000000) + src_trg_dataset = src_trg_dataset.map( + lambda src, trg: (src_vocab.lookup(src), trg_vocab.lookup(trg)), + num_parallel_calls=10).prefetch(1000000) + + if is_train: + + def key_func(src_data, trg_data): + bucket_width = (max_len + num_buckets - 1) // num_buckets + bucket_id = tf.maximum( + tf.size(input=src_data) // bucket_width, + tf.size(input=trg_data) // bucket_width) + return tf.cast(tf.minimum(num_buckets, bucket_id), dtype=tf.int64) + + def reduce_func(unused_key, windowed_data): + return windowed_data.padded_batch( + batch_size_words, padded_shapes=([None], [None])) + + def window_size_func(key): + bucket_width = (max_len + num_buckets - 1) // num_buckets + key += 1 + size = (num_gpus * batch_size_words // (key * bucket_width)) + return tf.cast(size, dtype=tf.int64) + + src_trg_dataset = src_trg_dataset.filter( + lambda src, trg: tf.logical_and( + tf.size(input=src) <= max_len, + tf.size(input=trg) <= max_len)) + src_trg_dataset = src_trg_dataset.apply( + tf.data.experimental.group_by_window( + key_func=key_func, + reduce_func=reduce_func, + window_size_func=window_size_func)) + + else: + src_trg_dataset = src_trg_dataset.padded_batch( + batch_size * num_gpus, padded_shapes=([None], [None])) + + iterator = tf.data.make_initializable_iterator(src_trg_dataset) + tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer) + features, labels = iterator.get_next() + + if is_train: + session.run(iterator.initializer) + if iteration == 1: + session.run(tf.tables_initializer()) + return features, labels + + +def get_pretrained_variables_map(checkpoint_file_path, ignore_scope=None): + reader = tf.train.NewCheckpointReader( + tf.train.latest_checkpoint(checkpoint_file_path)) + saved_shapes = reader.get_variable_to_shape_map() + if ignore_scope is None: + var_names = sorted([(var.name, var.name.split(':')[0]) + for var in tf.global_variables() + if var.name.split(':')[0] in saved_shapes]) + else: + var_names = sorted([(var.name, var.name.split(':')[0]) + for var in tf.global_variables() + if var.name.split(':')[0] in saved_shapes and all( + scope not in var.name + for scope in ignore_scope)]) + restore_vars = [] + name2var = dict( + zip( + map(lambda x: x.name.split(':')[0], tf.global_variables()), + 
tf.global_variables())) + restore_map = {} + with tf.variable_scope('', reuse=True): + for var_name, saved_var_name in var_names: + curr_var = name2var[saved_var_name] + var_shape = curr_var.get_shape().as_list() + if var_shape == saved_shapes[saved_var_name]: + restore_vars.append(curr_var) + restore_map[saved_var_name] = curr_var + tf.logging.info('Restore paramter %s from %s ...' % + (saved_var_name, checkpoint_file_path)) + return restore_map diff --git a/tests/pipelines/test_csanmt_translation.py b/tests/pipelines/test_csanmt_translation.py index a5c29f16..c43011fc 100644 --- a/tests/pipelines/test_csanmt_translation.py +++ b/tests/pipelines/test_csanmt_translation.py @@ -4,19 +4,25 @@ import unittest from modelscope.hub.snapshot_download import snapshot_download from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import TranslationPipeline from modelscope.utils.constant import Tasks from modelscope.utils.test_utils import test_level class TranslationTest(unittest.TestCase): - model_id = 'damo/nlp_csanmt_translation' - inputs = 'Gut@@ ach : Incre@@ ased safety for pedestri@@ ans' + model_id = 'damo/nlp_csanmt_translation_zh2en' + inputs = '声明 补充 说 , 沃伦 的 同事 都 深感 震惊 , 并且 希望 他 能够 投@@ 案@@ 自@@ 首 。' @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_model_name(self): pipeline_ins = pipeline(task=Tasks.translation, model=self.model_id) print(pipeline_ins(input=self.inputs)) + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_default_model(self): + pipeline_ins = pipeline(task=Tasks.translation) + print(pipeline_ins(input=self.inputs)) + if __name__ == '__main__': unittest.main() diff --git a/tests/trainers/test_translation_trainer.py b/tests/trainers/test_translation_trainer.py new file mode 100644 index 00000000..71bed241 --- /dev/null +++ b/tests/trainers/test_translation_trainer.py @@ -0,0 +1,18 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.trainers.nlp import CsanmtTranslationTrainer +from modelscope.utils.test_utils import test_level + + +class TranslationTest(unittest.TestCase): + model_id = 'damo/nlp_csanmt_translation_zh2en' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name(self): + trainer = CsanmtTranslationTrainer(model=self.model_id) + trainer.train() + + +if __name__ == '__main__': + unittest.main()
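Taken together, the new trainer and the updated pipeline are meant to compose: fine-tune from the released checkpoint, then translate through the registered task. A rough end-to-end sketch, assuming the default `damo/nlp_csanmt_translation_zh2en` model and pre-segmented ('@@' BPE) input as in the tests above:

from modelscope.pipelines import pipeline
from modelscope.trainers.nlp import CsanmtTranslationTrainer
from modelscope.utils.constant import Tasks

# Fine-tune: train_src/train_trg, vocab files and all hyperparameters are
# read from the model's configuration.json.
trainer = CsanmtTranslationTrainer(model='damo/nlp_csanmt_translation_zh2en')
trainer.train()

# Inference through the registered pipeline (this model id is also the new
# task default in builder.py).
translator = pipeline(task=Tasks.translation,
                      model='damo/nlp_csanmt_translation_zh2en')
print(translator(input='声明 补充 说 , 沃伦 的 同事 都 深感 震惊 , 并且 希望 他 能够 投@@ 案@@ 自@@ 首 。'))

During training, checkpoints are saved every `save_checkpoints_steps` iterations, as configured in the train section of configuration.json.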