* csanmt finetune wxpmaster
@@ -21,9 +21,11 @@ class CsanmtForTranslation(Model):
             params (dict): the model configuration.
         """
         super().__init__(model_dir, *args, **kwargs)
-        self.params = kwargs
+        self.params = kwargs['params']

-    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
+    def __call__(self,
+                 input: Dict[str, Tensor],
+                 label: Dict[str, Tensor] = None) -> Dict[str, Tensor]:
         """return the result by the model

         Args:
@@ -32,12 +34,32 @@ class CsanmtForTranslation(Model):
         Returns:
             output_seqs: output sequence of target ids
         """
-        with tf.compat.v1.variable_scope('NmtModel'):
-            output_seqs, output_scores = self.beam_search(input, self.params)
-        return {
-            'output_seqs': output_seqs,
-            'output_scores': output_scores,
-        }
+        if label is None:
+            with tf.compat.v1.variable_scope('NmtModel'):
+                output_seqs, output_scores = self.beam_search(
+                    input, self.params)
+            return {
+                'output_seqs': output_seqs,
+                'output_scores': output_scores,
+            }
+        else:
+            train_op, loss = self.transformer_model_train_fn(input, label)
+            return {
+                'train_op': train_op,
+                'loss': loss,
+            }
+
+    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        """
+        Run the forward pass for a model.
+
+        Args:
+            input (Dict[str, Tensor]): the dict of the model inputs for the forward method
+
+        Returns:
+            Dict[str, Tensor]: output from the model forward pass
+        """
+        ...

     def encoding_graph(self, features, params):
         src_vocab_size = params['src_vocab_size']
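
For reference, a minimal sketch of how the two call modes introduced above are driven (it mirrors the pipeline and trainer further down in this diff; the checkpoint path and the `params` dict are placeholders, and a real graph would normally build only one of the two branches):

```python
import tensorflow as tf

from modelscope.models.nlp import CsanmtForTranslation

params = {}  # hyper-parameter dict; in practice filled from configuration.json
model = CsanmtForTranslation('path/to/tf_ckpts/ckpt-0', params=params)

source_wids = tf.compat.v1.placeholder(
    dtype=tf.int64, shape=[None, None], name='source_wids')
target_wids = tf.compat.v1.placeholder(
    dtype=tf.int64, shape=[None, None], name='target_wids')

# label is None -> beam search, returns {'output_seqs', 'output_scores'}
infer_outputs = model(input=source_wids)
# label given -> training graph, returns {'train_op', 'loss'}
train_outputs = model(input=source_wids, label=target_wids)
```
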
@@ -137,6 +159,278 @@ class CsanmtForTranslation(Model):
                                                        params)
         return encoder_output

+    def build_contrastive_training_graph(self, features, labels, params):
+        # representations
+        source_name = 'source'
+        target_name = 'target'
+        if params['shared_source_target_embedding']:
+            source_name = None
+            target_name = None
+        feature_output = self.semantic_encoding_graph(
+            features, params, name=source_name)
+        label_output = self.semantic_encoding_graph(
+            labels, params, name=target_name)
+        return feature_output, label_output
+
+    def MGMC_sampling(self, x_embedding, y_embedding, params, epsilon=1e-12):
+        K = params['num_of_samples']
+        eta = params['eta']
+        assert K % 2 == 0
+
+        def get_samples(x_vector, y_vector):
+            bias_vector = y_vector - x_vector
+            w_r = tf.math.divide(
+                tf.abs(bias_vector) - tf.reduce_min(
+                    input_tensor=tf.abs(bias_vector), axis=2, keepdims=True)
+                + epsilon,
+                tf.reduce_max(
+                    input_tensor=tf.abs(bias_vector), axis=2, keepdims=True)
+                - tf.reduce_min(
+                    input_tensor=tf.abs(bias_vector), axis=2, keepdims=True)
+                + 2 * epsilon)
+            R = []
+            for i in range(K // 2):
+                omega = eta * tf.random.normal(tf.shape(input=bias_vector), 0.0, w_r) + \
+                    (1.0 - eta) * tf.random.normal(tf.shape(input=bias_vector), 0.0, 1.0)
+                sample = x_vector + omega * bias_vector
+                R.append(sample)
+            return R
+
+        ALL_SAMPLES = []
+        ALL_SAMPLES = get_samples(x_embedding, y_embedding)
+        ALL_SAMPLES.extend(get_samples(y_embedding, x_embedding))
+        assert len(ALL_SAMPLES) == K
+        return tf.concat(ALL_SAMPLES, axis=0)
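
In equation form, `MGMC_sampling` above draws, for an aligned pair of sentence embeddings (x, y) with bias b = y - x (min/max taken over the hidden dimension), K/2 samples from (x, y) and K/2 from (y, x), and concatenates them along the batch axis:

```latex
W_r = \frac{\lvert b\rvert - \min\lvert b\rvert + \epsilon}
           {\max\lvert b\rvert - \min\lvert b\rvert + 2\epsilon},
\qquad
\omega = \eta\,\mathcal{N}(0, W_r) + (1 - \eta)\,\mathcal{N}(0, 1),
\qquad
\tilde{z} = x + \omega \odot b
```
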
+
+    def decoding_graph(self,
+                       encoder_output,
+                       encoder_self_attention_bias,
+                       labels,
+                       params={},
+                       embedding_augmentation=None):
+        trg_vocab_size = params['trg_vocab_size']
+        hidden_size = params['hidden_size']
+        initializer = tf.compat.v1.random_normal_initializer(
+            0.0, hidden_size**-0.5, dtype=tf.float32)
+        if params['shared_source_target_embedding']:
+            with tf.compat.v1.variable_scope(
+                    'Shared_Embedding', reuse=tf.compat.v1.AUTO_REUSE):
+                trg_embedding = tf.compat.v1.get_variable(
+                    'Weights', [trg_vocab_size, hidden_size],
+                    initializer=initializer)
+        else:
+            with tf.compat.v1.variable_scope('Target_Embedding'):
+                trg_embedding = tf.compat.v1.get_variable(
+                    'Weights', [trg_vocab_size, hidden_size],
+                    initializer=initializer)
+
+        eos_padding = tf.zeros([tf.shape(input=labels)[0], 1], tf.int64)
+        trg_seq = tf.concat([labels, eos_padding], 1)
+        trg_mask = tf.cast(tf.not_equal(trg_seq, 0), dtype=tf.float32)
+        shift_trg_mask = trg_mask[:, :-1]
+        shift_trg_mask = tf.pad(
+            tensor=shift_trg_mask,
+            paddings=[[0, 0], [1, 0]],
+            constant_values=1)
+
+        decoder_input = tf.gather(trg_embedding, tf.cast(trg_seq, tf.int32))
+        decoder_input *= hidden_size**0.5
+        decoder_self_attention_bias = attention_bias(
+            tf.shape(input=decoder_input)[1], 'causal')
+        decoder_input = tf.pad(
+            tensor=decoder_input, paddings=[[0, 0], [1, 0], [0, 0]])[:, :-1, :]
+        if params['position_info_type'] == 'absolute':
+            decoder_input = add_timing_signal(decoder_input)
+        decoder_input = tf.nn.dropout(
+            decoder_input, rate=1 - (1.0 - params['residual_dropout']))
+
+        # training
+        decoder_output, attention_weights = transformer_decoder(
+            decoder_input,
+            encoder_output,
+            decoder_self_attention_bias,
+            encoder_self_attention_bias,
+            states_key=None,
+            states_val=None,
+            embedding_augmentation=embedding_augmentation,
+            params=params)
+        logits = self.prediction(decoder_output, params)
+
+        on_value = params['confidence']
+        off_value = (1.0 - params['confidence']) / tf.cast(
+            trg_vocab_size - 1, dtype=tf.float32)
+        soft_targets = tf.one_hot(
+            tf.cast(trg_seq, tf.int32),
+            depth=trg_vocab_size,
+            on_value=on_value,
+            off_value=off_value)
+        mask = tf.cast(shift_trg_mask, logits.dtype)
+        xentropy = tf.nn.softmax_cross_entropy_with_logits(
+            logits=logits, labels=tf.stop_gradient(soft_targets)) * mask
+        loss = tf.reduce_sum(input_tensor=xentropy) / tf.reduce_sum(
+            input_tensor=mask)
+        return loss
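
The loss computed at the end of `decoding_graph` is standard label-smoothed cross entropy averaged over non-padding target positions; with vocabulary size V, target mask m_t and smoothed target distribution q_t:

```latex
q_t(k) =
\begin{cases}
  \text{confidence}, & k = y_t \\[4pt]
  \dfrac{1 - \text{confidence}}{V - 1}, & k \neq y_t
\end{cases}
\qquad
\mathcal{L} =
\frac{\sum_t m_t \, \mathrm{CE}\!\left(q_t,\ \mathrm{softmax}(\mathrm{logits}_t)\right)}
     {\sum_t m_t}
```
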
+
+    def build_training_graph(self,
+                             features,
+                             labels,
+                             params,
+                             feature_embedding=None,
+                             label_embedding=None):
+        # encode
+        encoder_output, encoder_self_attention_bias = self.encoding_graph(
+            features, params)
+        embedding_augmentation = None
+        if feature_embedding is not None and label_embedding is not None:
+            embedding_augmentation = self.MGMC_sampling(
+                feature_embedding, label_embedding, params)
+            encoder_output = tf.tile(encoder_output,
+                                     [params['num_of_samples'], 1, 1])
+            encoder_self_attention_bias = tf.tile(
+                encoder_self_attention_bias,
+                [params['num_of_samples'], 1, 1, 1])
+            labels = tf.tile(labels, [params['num_of_samples'], 1])
+        # decode
+        loss = self.decoding_graph(
+            encoder_output,
+            encoder_self_attention_bias,
+            labels,
+            params,
+            embedding_augmentation=embedding_augmentation)
+        return loss
+
+    def transformer_model_train_fn(self, features, labels):
+        initializer = get_initializer(self.params)
+        with tf.compat.v1.variable_scope('NmtModel', initializer=initializer):
+            num_gpus = self.params['num_gpus']
+            gradient_clip_norm = self.params['gradient_clip_norm']
+            global_step = tf.compat.v1.train.get_global_step()
+            print(global_step)
+
+            # learning rate
+            learning_rate = get_learning_rate_decay(
+                self.params['learning_rate'], global_step, self.params)
+            learning_rate = tf.convert_to_tensor(
+                value=learning_rate, dtype=tf.float32)
+
+            # optimizer
+            if self.params['optimizer'] == 'sgd':
+                optimizer = tf.compat.v1.train.GradientDescentOptimizer(
+                    learning_rate)
+            elif self.params['optimizer'] == 'adam':
+                optimizer = tf.compat.v1.train.AdamOptimizer(
+                    learning_rate=learning_rate,
+                    beta1=self.params['adam_beta1'],
+                    beta2=self.params['adam_beta2'],
+                    epsilon=self.params['adam_epsilon'])
+            else:
+                tf.compat.v1.logging.info('optimizer not supported')
+                sys.exit()
+            opt = MultiStepOptimizer(optimizer, self.params['update_cycle'])
+
+            def fill_gpus(inputs, num_gpus):
+                outputs = inputs
+                for i in range(num_gpus):
+                    outputs = tf.concat([outputs, inputs], axis=0)
+                outputs = outputs[:num_gpus, ]
+                return outputs
+
+            features = tf.cond(
+                pred=tf.shape(input=features)[0] < num_gpus,
+                true_fn=lambda: fill_gpus(features, num_gpus),
+                false_fn=lambda: features)
+            labels = tf.cond(
+                pred=tf.shape(input=labels)[0] < num_gpus,
+                true_fn=lambda: fill_gpus(labels, num_gpus),
+                false_fn=lambda: labels)
+
+            if num_gpus > 0:
+                feature_shards = shard_features(features, num_gpus)
+                label_shards = shard_features(labels, num_gpus)
+            else:
+                feature_shards = [features]
+                label_shards = [labels]
+            if num_gpus > 0:
+                devices = ['gpu:%d' % d for d in range(num_gpus)]
+            else:
+                devices = ['cpu:0']
+
+            multi_grads = []
+            sharded_losses = []
+            for i, device in enumerate(devices):
+                with tf.device(device), tf.compat.v1.variable_scope(
+                        tf.compat.v1.get_variable_scope(),
+                        reuse=True if i > 0 else None):
+                    with tf.name_scope('%s_%d' % ('GPU', i)):
+                        feature_output, label_output = self.build_contrastive_training_graph(
+                            feature_shards[i], label_shards[i], self.params)
+                        mle_loss = self.build_training_graph(
+                            feature_shards[i], label_shards[i], self.params,
+                            feature_output, label_output)
+                        sharded_losses.append(mle_loss)
+                        tf.compat.v1.summary.scalar('mle_loss_{}'.format(i),
+                                                    mle_loss)
+
+                        # Optimization
+                        trainable_vars_list = [
+                            v for v in tf.compat.v1.trainable_variables()
+                            if 'Shared_Semantic_Embedding' not in v.name
+                            and 'mini_xlm_encoder' not in v.name
+                        ]
+                        grads_and_vars = opt.compute_gradients(
+                            mle_loss,
+                            var_list=trainable_vars_list,
+                            colocate_gradients_with_ops=True)
+                        multi_grads.append(grads_and_vars)
+
+            total_loss = tf.add_n(sharded_losses) / len(sharded_losses)
+
+            # Average gradients
+            grads_and_vars = average_gradients(multi_grads)
+            if gradient_clip_norm > 0.0:
+                grads, var_list = list(zip(*grads_and_vars))
+                grads, _ = tf.clip_by_global_norm(grads, gradient_clip_norm)
+                grads_and_vars = zip(grads, var_list)
+            train_op = opt.apply_gradients(
+                grads_and_vars,
+                global_step=tf.compat.v1.train.get_global_step())
+
+            return train_op, total_loss
+
+    def prediction(self, decoder_output, params):
+        hidden_size = params['hidden_size']
+        trg_vocab_size = params['trg_vocab_size']
+        if params['shared_embedding_and_softmax_weights']:
+            embedding_scope = 'Shared_Embedding' if params[
+                'shared_source_target_embedding'] else 'Target_Embedding'
+            with tf.compat.v1.variable_scope(embedding_scope, reuse=True):
+                weights = tf.compat.v1.get_variable('Weights')
+        else:
+            weights = tf.compat.v1.get_variable('Softmax',
+                                                [trg_vocab_size, hidden_size])
+        shape = tf.shape(input=decoder_output)[:-1]
+        decoder_output = tf.reshape(decoder_output, [-1, hidden_size])
+        logits = tf.matmul(decoder_output, weights, transpose_b=True)
+        logits = tf.reshape(logits, tf.concat([shape, [trg_vocab_size]], 0))
+        return logits

     def inference_func(self,
                        encoder_output,
                        feature_output,
@@ -193,7 +487,7 @@ class CsanmtForTranslation(Model):
                 weights = tf.compat.v1.get_variable('Weights')
         else:
             weights = tf.compat.v1.get_variable('Softmax',
-                                                [tgt_vocab_size, hidden_size])
+                                                [trg_vocab_size, hidden_size])
         logits = tf.matmul(decoder_output_last, weights, transpose_b=True)
         log_prob = tf.nn.log_softmax(logits)
         return log_prob, attention_weights_last, states_key, states_val
@@ -212,7 +506,11 @@ class CsanmtForTranslation(Model):
         encoder_output, encoder_self_attention_bias = self.encoding_graph(
             features, params)
-        feature_output = self.semantic_encoding_graph(features, params)
+        source_name = 'source'
+        if params['shared_source_target_embedding']:
+            source_name = None
+        feature_output = self.semantic_encoding_graph(
+            features, params, name=source_name)

         init_seqs = tf.fill([batch_size, beam_size, 1], 0)
         init_log_probs = \
@@ -585,7 +883,6 @@ def _residual_fn(x, y, keep_prob=None):
 def embedding_augmentation_layer(x, embedding_augmentation, params, name=None):
     hidden_size = params['hidden_size']
     keep_prob = 1.0 - params['relu_dropout']
-    layer_postproc = params['layer_postproc']
     with tf.compat.v1.variable_scope(
             name,
             default_name='embedding_augmentation_layer',
@@ -600,8 +897,7 @@ def embedding_augmentation_layer(x, embedding_augmentation, params, name=None):
         with tf.compat.v1.variable_scope('output_layer'):
             output = linear(hidden, hidden_size, True, True)
-        x = _layer_process(x + output, layer_postproc)
-        return x
+        return x + output


 def transformer_ffn_layer(x, params, name=None):
@@ -740,8 +1036,10 @@ def transformer_decoder(decoder_input,
                 if params['position_info_type'] == 'relative' else None
             # continuous semantic augmentation
             if embedding_augmentation is not None:
-                x = embedding_augmentation_layer(x, embedding_augmentation,
-                                                 params)
+                x = embedding_augmentation_layer(
+                    x, _layer_process(embedding_augmentation,
+                                      layer_preproc), params)
+                x = _layer_process(x, layer_postproc)
             o, w = multihead_attention(
                 _layer_process(x, layer_preproc),
                 None,
@@ -1004,3 +1302,191 @@ def multihead_attention(queries,
         w = tf.reduce_mean(w, 1)
     x = linear(x, output_depth, True, True, scope='output_transform')
     return x, w
+
+
+def get_initializer(params):
+    if params['initializer'] == 'uniform':
+        max_val = params['initializer_scale']
+        return tf.compat.v1.random_uniform_initializer(-max_val, max_val)
+    elif params['initializer'] == 'normal':
+        return tf.compat.v1.random_normal_initializer(
+            0.0, params['initializer_scale'])
+    elif params['initializer'] == 'normal_unit_scaling':
+        return tf.compat.v1.variance_scaling_initializer(
+            params['initializer_scale'], mode='fan_avg', distribution='normal')
+    elif params['initializer'] == 'uniform_unit_scaling':
+        return tf.compat.v1.variance_scaling_initializer(
+            params['initializer_scale'],
+            mode='fan_avg',
+            distribution='uniform')
+    else:
+        raise ValueError('Unrecognized initializer: %s'
+                         % params['initializer'])
+
+
+def get_learning_rate_decay(learning_rate, global_step, params):
+    if params['learning_rate_decay'] in ['linear_warmup_rsqrt_decay', 'noam']:
+        step = tf.cast(global_step, dtype=tf.float32)
+        warmup_steps = tf.cast(params['warmup_steps'], dtype=tf.float32)
+        multiplier = params['hidden_size']**-0.5
+        decay = multiplier * tf.minimum((step + 1) * (warmup_steps**-1.5),
+                                        (step + 1)**-0.5)
+        return learning_rate * decay
+    elif params['learning_rate_decay'] == 'piecewise_constant':
+        return tf.compat.v1.train.piecewise_constant(
+            tf.cast(global_step, dtype=tf.int32),
+            params['learning_rate_boundaries'], params['learning_rate_values'])
+    elif params['learning_rate_decay'] == 'none':
+        return learning_rate
+    else:
+        raise ValueError('Unknown learning_rate_decay')
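
For the 'noam'-style branch above, with global step s, hidden size d, warmup steps w and the configured base learning rate, the effective learning rate is:

```latex
\mathrm{lr}(s) = \mathrm{lr}_{\mathrm{base}} \cdot d^{-0.5} \cdot
                 \min\!\big((s + 1)\,w^{-1.5},\ (s + 1)^{-0.5}\big)
```
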
+
+
+def average_gradients(tower_grads):
+    average_grads = []
+    for grad_and_vars in zip(*tower_grads):
+        grads = []
+        for g, _ in grad_and_vars:
+            expanded_g = tf.expand_dims(g, 0)
+            grads.append(expanded_g)
+        grad = tf.concat(axis=0, values=grads)
+        grad = tf.reduce_mean(grad, 0)
+        v = grad_and_vars[0][1]
+        grad_and_var = (grad, v)
+        average_grads.append(grad_and_var)
+    return average_grads
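
A tiny eager-mode illustration of the per-tower layout `average_gradients` expects (the training code above runs in graph mode; the values here are made up):

```python
import tensorflow as tf

v = tf.Variable([1.0, 2.0])
tower_grads = [
    [(tf.constant([0.2, 0.4]), v)],  # gradients computed on device 0
    [(tf.constant([0.4, 0.8]), v)],  # gradients computed on device 1
]
averaged = average_gradients(tower_grads)
# averaged == [(<tf.Tensor [0.3, 0.6]>, v)]
```
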
+
+
+_ENGINE = None
+
+
+def all_reduce(tensor):
+    if _ENGINE is None:
+        return tensor
+    return _ENGINE.allreduce(tensor, compression=_ENGINE.Compression.fp16)
+
+
+class MultiStepOptimizer(tf.compat.v1.train.Optimizer):
+
+    def __init__(self,
+                 optimizer,
+                 step=1,
+                 use_locking=False,
+                 name='MultiStepOptimizer'):
+        super(MultiStepOptimizer, self).__init__(use_locking, name)
+        self._optimizer = optimizer
+        self._step = step
+        self._step_t = tf.convert_to_tensor(step, name='step')
+
+    def _all_reduce(self, tensor):
+        with tf.name_scope(self._name + '_Allreduce'):
+            if tensor is None:
+                return tensor
+            if isinstance(tensor, tf.IndexedSlices):
+                tensor = tf.convert_to_tensor(tensor)
+            return all_reduce(tensor)
+
+    def compute_gradients(self,
+                          loss,
+                          var_list=None,
+                          gate_gradients=tf.compat.v1.train.Optimizer.GATE_OP,
+                          aggregation_method=None,
+                          colocate_gradients_with_ops=False,
+                          grad_loss=None):
+        grads_and_vars = self._optimizer.compute_gradients(
+            loss, var_list, gate_gradients, aggregation_method,
+            colocate_gradients_with_ops, grad_loss)
+        grads, var_list = list(zip(*grads_and_vars))
+
+        # Do not create extra variables when step is 1
+        if self._step == 1:
+            grads = [self._all_reduce(t) for t in grads]
+            return list(zip(grads, var_list))
+
+        first_var = min(var_list, key=lambda x: x.name)
+        iter_var = self._create_non_slot_variable(
+            initial_value=0 if self._step == 1 else 1,
+            name='iter',
+            colocate_with=first_var)
+
+        new_grads = []
+        for grad, var in zip(grads, var_list):
+            grad_acc = self._zeros_slot(var, 'grad_acc', self._name)
+            if isinstance(grad, tf.IndexedSlices):
+                grad_acc = tf.scatter_add(
+                    grad_acc,
+                    grad.indices,
+                    grad.values,
+                    use_locking=self._use_locking)
+            else:
+                grad_acc = tf.assign_add(
+                    grad_acc, grad, use_locking=self._use_locking)
+
+            def _acc_grad():
+                return grad_acc
+
+            def _avg_grad():
+                return self._all_reduce(grad_acc / self._step)
+
+            grad = tf.cond(tf.equal(iter_var, 0), _avg_grad, _acc_grad)
+            new_grads.append(grad)
+
+        return list(zip(new_grads, var_list))
+
+    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
+        if self._step == 1:
+            return self._optimizer.apply_gradients(
+                grads_and_vars, global_step, name=name)
+
+        grads, var_list = list(zip(*grads_and_vars))
+
+        def _pass_gradients():
+            return tf.group(*grads)
+
+        def _apply_gradients():
+            op = self._optimizer.apply_gradients(
+                zip(grads, var_list), global_step, name)
+            with tf.control_dependencies([op]):
+                zero_ops = []
+                for var in var_list:
+                    grad_acc = self.get_slot(var, 'grad_acc')
+                    zero_ops.append(
+                        grad_acc.assign(
+                            tf.zeros_like(grad_acc),
+                            use_locking=self._use_locking))
+                zero_op = tf.group(*zero_ops)
+            return tf.group(*[op, zero_op])
+
+        iter_var = self._get_non_slot_variable('iter', tf.get_default_graph())
+        update_op = tf.cond(
+            tf.equal(iter_var, 0), _apply_gradients, _pass_gradients)
+        with tf.control_dependencies([update_op]):
+            iter_op = iter_var.assign(
+                tf.mod(iter_var + 1, self._step_t),
+                use_locking=self._use_locking)
+
+        return tf.group(*[update_op, iter_op])
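
A hedged sketch of how the wrapper is meant to be used (graph mode, matching `transformer_model_train_fn` above; `loss` stands for the training loss already built in the graph):

```python
base_opt = tf.compat.v1.train.AdamOptimizer(learning_rate=1e-4)
opt = MultiStepOptimizer(base_opt, step=4)  # step corresponds to update_cycle

grads_and_vars = opt.compute_gradients(loss)  # accumulates into 'grad_acc' slots
train_op = opt.apply_gradients(
    grads_and_vars, global_step=tf.compat.v1.train.get_global_step())
# Variables are only updated on every 4th session.run(train_op); the other
# runs just add the current gradients to the accumulators.
```
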
+
+
+def shard_features(x, num_datashards):
+    x = tf.convert_to_tensor(x)
+    batch_size = tf.shape(x)[0]
+    size_splits = []
+    with tf.device('/cpu:0'):
+        for i in range(num_datashards):
+            size_splits.append(
+                tf.cond(
+                    tf.greater(
+                        tf.compat.v1.mod(batch_size, num_datashards),
+                        i), lambda: batch_size // num_datashards + 1,
+                    lambda: batch_size // num_datashards))
+    return tf.split(x, size_splits, axis=0)
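
For example, with the tensors built in `transformer_model_train_fn` above, a batch of 10 sentence pairs split across 3 shards yields per-shard batch sizes of 4, 3 and 3 (the remainder goes to the first shards):

```python
feature_shards = shard_features(features, num_datashards=3)
label_shards = shard_features(labels, num_datashards=3)
```
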
@@ -27,7 +27,7 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     (Pipelines.sentence_similarity,
      'damo/nlp_structbert_sentence-similarity_chinese-base'),
     Tasks.translation: (Pipelines.csanmt_translation,
-                        'damo/nlp_csanmt_translation'),
+                        'damo/nlp_csanmt_translation_zh2en'),
     Tasks.nli: (Pipelines.nli, 'damo/nlp_structbert_nli_chinese-base'),
     Tasks.sentiment_classification:
     (Pipelines.sentiment_classification,
@@ -1,4 +1,5 @@
 import os.path as osp
+from threading import Lock
 from typing import Any, Dict

 import numpy as np
@@ -8,59 +9,38 @@ from modelscope.metainfo import Pipelines
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Pipeline
 from modelscope.pipelines.builder import PIPELINES
-from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.config import Config
+from modelscope.utils.constant import Frameworks, ModelFile, Tasks
 from modelscope.utils.logger import get_logger

 if tf.__version__ >= '2.0':
     tf = tf.compat.v1
-    tf.disable_eager_execution()
+    tf.disable_eager_execution()

 logger = get_logger()

 __all__ = ['TranslationPipeline']

-# constant
-PARAMS = {
-    'hidden_size': 512,
-    'filter_size': 2048,
-    'num_heads': 8,
-    'num_encoder_layers': 6,
-    'num_decoder_layers': 6,
-    'attention_dropout': 0.0,
-    'residual_dropout': 0.0,
-    'relu_dropout': 0.0,
-    'layer_preproc': 'none',
-    'layer_postproc': 'layer_norm',
-    'shared_embedding_and_softmax_weights': True,
-    'shared_source_target_embedding': True,
-    'initializer_scale': 0.1,
-    'train_max_len': 100,
-    'confidence': 0.9,
-    'position_info_type': 'absolute',
-    'max_relative_dis': 16,
-    'beam_size': 4,
-    'lp_rate': 0.6,
-    'num_semantic_encoder_layers': 4,
-    'max_decoded_trg_len': 100,
-    'src_vocab_size': 37006,
-    'trg_vocab_size': 37006,
-    'vocab_src': 'src_vocab.txt',
-    'vocab_trg': 'trg_vocab.txt'
-}

 @PIPELINES.register_module(
     Tasks.translation, module_name=Pipelines.csanmt_translation)
 class TranslationPipeline(Pipeline):

     def __init__(self, model: str, **kwargs):
-        super().__init__(model=model)
-        model = self.model.model_dir
         tf.reset_default_graph()
+        self.framework = Frameworks.tf
+        self.device_name = 'cpu'
+        super().__init__(model=model)
         model_path = osp.join(
             osp.join(model, ModelFile.TF_CHECKPOINT_FOLDER), 'ckpt-0')

-        self.params = PARAMS
+        self.cfg = Config.from_file(osp.join(model, ModelFile.CONFIGURATION))
+        self.params = {}
+        self._override_params_from_file()
         self._src_vocab_path = osp.join(model, self.params['vocab_src'])
         self._src_vocab = dict([
             (w.strip(), i) for i, w in enumerate(open(self._src_vocab_path))
@@ -70,15 +50,16 @@ class TranslationPipeline(Pipeline):
             (i, w.strip()) for i, w in enumerate(open(self._trg_vocab_path))
         ])

-        config = tf.ConfigProto(allow_soft_placement=True)
-        config.gpu_options.allow_growth = True
-        self._session = tf.Session(config=config)
+        tf_config = tf.ConfigProto(allow_soft_placement=True)
+        tf_config.gpu_options.allow_growth = True
+        self._session = tf.Session(config=tf_config)

         self.input_wids = tf.placeholder(
             dtype=tf.int64, shape=[None, None], name='input_wids')
         self.output = {}

         # model
+        self.model = CsanmtForTranslation(model_path, params=self.params)
         output = self.model(self.input_wids)
         self.output.update(output)
@@ -88,6 +69,49 @@ class TranslationPipeline(Pipeline):
             model_loader = tf.train.Saver(tf.global_variables())
             model_loader.restore(sess, model_path)

+    def _override_params_from_file(self):
+        # model
+        self.params['hidden_size'] = self.cfg['model']['hidden_size']
+        self.params['filter_size'] = self.cfg['model']['filter_size']
+        self.params['num_heads'] = self.cfg['model']['num_heads']
+        self.params['num_encoder_layers'] = self.cfg['model'][
+            'num_encoder_layers']
+        self.params['num_decoder_layers'] = self.cfg['model'][
+            'num_decoder_layers']
+        self.params['layer_preproc'] = self.cfg['model']['layer_preproc']
+        self.params['layer_postproc'] = self.cfg['model']['layer_postproc']
+        self.params['shared_embedding_and_softmax_weights'] = self.cfg[
+            'model']['shared_embedding_and_softmax_weights']
+        self.params['shared_source_target_embedding'] = self.cfg['model'][
+            'shared_source_target_embedding']
+        self.params['initializer_scale'] = self.cfg['model'][
+            'initializer_scale']
+        self.params['position_info_type'] = self.cfg['model'][
+            'position_info_type']
+        self.params['max_relative_dis'] = self.cfg['model']['max_relative_dis']
+        self.params['num_semantic_encoder_layers'] = self.cfg['model'][
+            'num_semantic_encoder_layers']
+        self.params['src_vocab_size'] = self.cfg['model']['src_vocab_size']
+        self.params['trg_vocab_size'] = self.cfg['model']['trg_vocab_size']
+        self.params['attention_dropout'] = 0.0
+        self.params['residual_dropout'] = 0.0
+        self.params['relu_dropout'] = 0.0
+
+        # dataset
+        self.params['vocab_src'] = self.cfg['dataset']['src_vocab']['file']
+        self.params['vocab_trg'] = self.cfg['dataset']['trg_vocab']['file']
+
+        # train
+        self.params['train_max_len'] = self.cfg['train']['train_max_len']
+        self.params['confidence'] = self.cfg['train']['confidence']
+
+        # evaluation
+        self.params['beam_size'] = self.cfg['evaluation']['beam_size']
+        self.params['lp_rate'] = self.cfg['evaluation']['lp_rate']
+        self.params['max_decoded_trg_len'] = self.cfg['evaluation'][
+            'max_decoded_trg_len']
+
     def preprocess(self, input: str) -> Dict[str, Any]:
         input_ids = np.array([[
             self._src_vocab[w]
@@ -5,10 +5,11 @@ from modelscope.utils.import_utils import LazyImportModule

 if TYPE_CHECKING:
     from .sequence_classification_trainer import SequenceClassificationTrainer
+    from .csanmt_translation_trainer import CsanmtTranslationTrainer
 else:
     _import_structure = {
-        'sequence_classification_trainer': ['SequenceClassificationTrainer']
+        'sequence_classification_trainer': ['SequenceClassificationTrainer'],
+        'csanmt_translation_trainer': ['CsanmtTranslationTrainer'],
     }

     import sys
@@ -0,0 +1,324 @@
+import os.path as osp
+from typing import Dict, Optional
+
+import tensorflow as tf
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.models.nlp import CsanmtForTranslation
+from modelscope.trainers.base import BaseTrainer
+from modelscope.trainers.builder import TRAINERS
+from modelscope.utils.constant import ModelFile
+from modelscope.utils.logger import get_logger
+
+if tf.__version__ >= '2.0':
+    tf = tf.compat.v1
+    tf.disable_eager_execution()
+
+logger = get_logger()
+
+
+@TRAINERS.register_module(module_name=r'csanmt-translation')
+class CsanmtTranslationTrainer(BaseTrainer):
+
+    def __init__(self, model: str, cfg_file: str = None, *args, **kwargs):
+        if not osp.exists(model):
+            model = snapshot_download(model)
+        tf.reset_default_graph()
+
+        self.model_dir = model
+        self.model_path = osp.join(model, ModelFile.TF_CHECKPOINT_FOLDER)
+        if cfg_file is None:
+            cfg_file = osp.join(model, ModelFile.CONFIGURATION)
+        super().__init__(cfg_file)
+
+        self.params = {}
+        self._override_params_from_file()
+
+        tf_config = tf.ConfigProto(allow_soft_placement=True)
+        tf_config.gpu_options.allow_growth = True
+        self._session = tf.Session(config=tf_config)
+
+        self.source_wids = tf.placeholder(
+            dtype=tf.int64, shape=[None, None], name='source_wids')
+        self.target_wids = tf.placeholder(
+            dtype=tf.int64, shape=[None, None], name='target_wids')
+        self.output = {}
+
+        self.global_step = tf.train.create_global_step()
+
+        self.model = CsanmtForTranslation(self.model_path, params=self.params)
+        output = self.model(input=self.source_wids, label=self.target_wids)
+        self.output.update(output)
+
+        self.model_saver = tf.train.Saver(
+            tf.global_variables(),
+            max_to_keep=self.params['keep_checkpoint_max'])
+
+        with self._session.as_default() as sess:
+            logger.info(f'loading model from {self.model_path}')
+            pretrained_variables_map = get_pretrained_variables_map(
+                self.model_path)
+            tf.train.init_from_checkpoint(self.model_path,
+                                          pretrained_variables_map)
+            sess.run(tf.global_variables_initializer())
+
+    def _override_params_from_file(self):
+        self.params['hidden_size'] = self.cfg['model']['hidden_size']
+        self.params['filter_size'] = self.cfg['model']['filter_size']
+        self.params['num_heads'] = self.cfg['model']['num_heads']
+        self.params['num_encoder_layers'] = self.cfg['model'][
+            'num_encoder_layers']
+        self.params['num_decoder_layers'] = self.cfg['model'][
+            'num_decoder_layers']
+        self.params['layer_preproc'] = self.cfg['model']['layer_preproc']
+        self.params['layer_postproc'] = self.cfg['model']['layer_postproc']
+        self.params['shared_embedding_and_softmax_weights'] = self.cfg[
+            'model']['shared_embedding_and_softmax_weights']
+        self.params['shared_source_target_embedding'] = self.cfg['model'][
+            'shared_source_target_embedding']
+        self.params['initializer_scale'] = self.cfg['model'][
+            'initializer_scale']
+        self.params['position_info_type'] = self.cfg['model'][
+            'position_info_type']
+        self.params['max_relative_dis'] = self.cfg['model']['max_relative_dis']
+        self.params['num_semantic_encoder_layers'] = self.cfg['model'][
+            'num_semantic_encoder_layers']
+        self.params['src_vocab_size'] = self.cfg['model']['src_vocab_size']
+        self.params['trg_vocab_size'] = self.cfg['model']['trg_vocab_size']
+        self.params['attention_dropout'] = 0.0
+        self.params['residual_dropout'] = 0.0
+        self.params['relu_dropout'] = 0.0
+
+        self.params['train_src'] = self.cfg['dataset']['train_src']
+        self.params['train_trg'] = self.cfg['dataset']['train_trg']
+        self.params['vocab_src'] = self.cfg['dataset']['src_vocab']['file']
+        self.params['vocab_trg'] = self.cfg['dataset']['trg_vocab']['file']
+
+        self.params['num_gpus'] = self.cfg['train']['num_gpus']
+        self.params['warmup_steps'] = self.cfg['train']['warmup_steps']
+        self.params['update_cycle'] = self.cfg['train']['update_cycle']
+        self.params['keep_checkpoint_max'] = self.cfg['train'][
+            'keep_checkpoint_max']
+        self.params['confidence'] = self.cfg['train']['confidence']
+        self.params['optimizer'] = self.cfg['train']['optimizer']
+        self.params['adam_beta1'] = self.cfg['train']['adam_beta1']
+        self.params['adam_beta2'] = self.cfg['train']['adam_beta2']
+        self.params['adam_epsilon'] = self.cfg['train']['adam_epsilon']
+        self.params['gradient_clip_norm'] = self.cfg['train'][
+            'gradient_clip_norm']
+        self.params['learning_rate_decay'] = self.cfg['train'][
+            'learning_rate_decay']
+        self.params['initializer'] = self.cfg['train']['initializer']
+        self.params['initializer_scale'] = self.cfg['train'][
+            'initializer_scale']
+        self.params['learning_rate'] = self.cfg['train']['learning_rate']
+        self.params['train_batch_size_words'] = self.cfg['train'][
+            'train_batch_size_words']
+        self.params['scale_l1'] = self.cfg['train']['scale_l1']
+        self.params['scale_l2'] = self.cfg['train']['scale_l2']
+        self.params['train_max_len'] = self.cfg['train']['train_max_len']
+        self.params['max_training_steps'] = self.cfg['train'][
+            'max_training_steps']
+        self.params['save_checkpoints_steps'] = self.cfg['train'][
+            'save_checkpoints_steps']
+        self.params['num_of_samples'] = self.cfg['train']['num_of_samples']
+        self.params['eta'] = self.cfg['train']['eta']
+
+        self.params['beam_size'] = self.cfg['evaluation']['beam_size']
+        self.params['lp_rate'] = self.cfg['evaluation']['lp_rate']
+        self.params['max_decoded_trg_len'] = self.cfg['evaluation'][
+            'max_decoded_trg_len']
+
+        self.params['seed'] = self.cfg['model']['seed']
+
+    def train(self, *args, **kwargs):
+        logger.info('Begin csanmt training')
+
+        train_src = osp.join(self.model_dir, self.params['train_src'])
+        train_trg = osp.join(self.model_dir, self.params['train_trg'])
+        vocab_src = osp.join(self.model_dir, self.params['vocab_src'])
+        vocab_trg = osp.join(self.model_dir, self.params['vocab_trg'])
+
+        iteration = 0
+        with self._session.as_default() as tf_session:
+            while True:
+                iteration += 1
+                if iteration >= self.params['max_training_steps']:
+                    break
+                train_input_fn = input_fn(
+                    train_src,
+                    train_trg,
+                    vocab_src,
+                    vocab_trg,
+                    batch_size_words=self.params['train_batch_size_words'],
+                    max_len=self.params['train_max_len'],
+                    num_gpus=self.params['num_gpus']
+                    if self.params['num_gpus'] > 0 else 1,
+                    is_train=True,
+                    session=tf_session,
+                    iteration=iteration)
+                features, labels = train_input_fn
+                features_batch, labels_batch = tf_session.run(
+                    [features, labels])
+                feed_dict = {
+                    self.source_wids: features_batch,
+                    self.target_wids: labels_batch
+                }
+                sess_outputs = self._session.run(
+                    self.output, feed_dict=feed_dict)
+                loss_step = sess_outputs['loss']
+                logger.info('Iteration: {}, step loss: {:.6f}'.format(
+                    iteration, loss_step))
+                if iteration % self.params['save_checkpoints_steps'] == 0:
+                    tf.logging.info('%s: Saving model on step: %d.' %
+                                    (__name__, iteration))
+                    ck_path = osp.join(self.model_dir, 'model.ckpt')
+                    self.model_saver.save(
+                        tf_session,
+                        ck_path,
+                        global_step=tf.train.get_global_step())
+        tf.logging.info('%s: NMT training completed.' % __name__)
+
+    def evaluate(self,
+                 checkpoint_path: Optional[str] = None,
+                 *args,
+                 **kwargs) -> Dict[str, float]:
+        """Evaluate a dataset.
+
+        Evaluate a dataset with the model restored from `checkpoint_path`; if
+        `checkpoint_path` does not exist, the path is read from the config file.
+
+        Args:
+            checkpoint_path (Optional[str], optional): the model path. Defaults to None.
+
+        Returns:
+            Dict[str, float]: the results of the evaluation.
+
+        Example:
+            {"accuracy": 0.5091743119266054, "f1": 0.673780487804878}
+        """
+        pass
+
+
+def input_fn(src_file,
+             trg_file,
+             src_vocab_file,
+             trg_vocab_file,
+             num_buckets=20,
+             max_len=100,
+             batch_size=200,
+             batch_size_words=4096,
+             num_gpus=1,
+             is_train=True,
+             session=None,
+             iteration=None):
+    src_vocab = tf.lookup.StaticVocabularyTable(
+        tf.lookup.TextFileInitializer(
+            src_vocab_file,
+            key_dtype=tf.string,
+            key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
+            value_dtype=tf.int64,
+            value_index=tf.lookup.TextFileIndex.LINE_NUMBER),
+        num_oov_buckets=1)  # NOTE unk -> vocab_size
+    trg_vocab = tf.lookup.StaticVocabularyTable(
+        tf.lookup.TextFileInitializer(
+            trg_vocab_file,
+            key_dtype=tf.string,
+            key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
+            value_dtype=tf.int64,
+            value_index=tf.lookup.TextFileIndex.LINE_NUMBER),
+        num_oov_buckets=1)  # NOTE unk -> vocab_size
+    src_dataset = tf.data.TextLineDataset(src_file)
+    trg_dataset = tf.data.TextLineDataset(trg_file)
+    src_trg_dataset = tf.data.Dataset.zip((src_dataset, trg_dataset))
+    src_trg_dataset = src_trg_dataset.map(
+        lambda src, trg:
+        (tf.string_split([src]).values, tf.string_split([trg]).values),
+        num_parallel_calls=10).prefetch(1000000)
+    src_trg_dataset = src_trg_dataset.map(
+        lambda src, trg: (src_vocab.lookup(src), trg_vocab.lookup(trg)),
+        num_parallel_calls=10).prefetch(1000000)
+
+    if is_train:
+
+        def key_func(src_data, trg_data):
+            bucket_width = (max_len + num_buckets - 1) // num_buckets
+            bucket_id = tf.maximum(
+                tf.size(input=src_data) // bucket_width,
+                tf.size(input=trg_data) // bucket_width)
+            return tf.cast(tf.minimum(num_buckets, bucket_id), dtype=tf.int64)
+
+        def reduce_func(unused_key, windowed_data):
+            return windowed_data.padded_batch(
+                batch_size_words, padded_shapes=([None], [None]))
+
+        def window_size_func(key):
+            bucket_width = (max_len + num_buckets - 1) // num_buckets
+            key += 1
+            size = (num_gpus * batch_size_words // (key * bucket_width))
+            return tf.cast(size, dtype=tf.int64)
+
+        src_trg_dataset = src_trg_dataset.filter(
+            lambda src, trg: tf.logical_and(
+                tf.size(input=src) <= max_len,
+                tf.size(input=trg) <= max_len))
+        src_trg_dataset = src_trg_dataset.apply(
+            tf.data.experimental.group_by_window(
+                key_func=key_func,
+                reduce_func=reduce_func,
+                window_size_func=window_size_func))
+    else:
+        src_trg_dataset = src_trg_dataset.padded_batch(
+            batch_size * num_gpus, padded_shapes=([None], [None]))
+
+    iterator = tf.data.make_initializable_iterator(src_trg_dataset)
+    tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
+    features, labels = iterator.get_next()
+    if is_train:
+        session.run(iterator.initializer)
+        if iteration == 1:
+            session.run(tf.tables_initializer())
+    return features, labels
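
A worked example of the bucketing above with the defaults (num_buckets=20, max_len=100, batch_size_words=4096, num_gpus=1); the file names and `sess` below are placeholders:

```python
# bucket_width = (100 + 20 - 1) // 20 = 5
# a pair with lengths (23, 31) falls into bucket max(23, 31) // 5 = 6
# window_size_func(6) -> 1 * 4096 // ((6 + 1) * 5) = 117 sentence pairs,
# so longer buckets are padded-batched with fewer pairs, keeping the
# token count per batch roughly constant.
features, labels = input_fn(
    'train.zh', 'train.en', 'src_vocab.txt', 'trg_vocab.txt',
    is_train=True, session=sess, iteration=1)
```
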
+
+
+def get_pretrained_variables_map(checkpoint_file_path, ignore_scope=None):
+    reader = tf.train.NewCheckpointReader(
+        tf.train.latest_checkpoint(checkpoint_file_path))
+    saved_shapes = reader.get_variable_to_shape_map()
+    if ignore_scope is None:
+        var_names = sorted([(var.name, var.name.split(':')[0])
+                            for var in tf.global_variables()
+                            if var.name.split(':')[0] in saved_shapes])
+    else:
+        var_names = sorted([(var.name, var.name.split(':')[0])
+                            for var in tf.global_variables()
+                            if var.name.split(':')[0] in saved_shapes and all(
+                                scope not in var.name
+                                for scope in ignore_scope)])
+    restore_vars = []
+    name2var = dict(
+        zip(
+            map(lambda x: x.name.split(':')[0], tf.global_variables()),
+            tf.global_variables()))
+    restore_map = {}
+    with tf.variable_scope('', reuse=True):
+        for var_name, saved_var_name in var_names:
+            curr_var = name2var[saved_var_name]
+            var_shape = curr_var.get_shape().as_list()
+            if var_shape == saved_shapes[saved_var_name]:
+                restore_vars.append(curr_var)
+                restore_map[saved_var_name] = curr_var
+                tf.logging.info('Restore parameter %s from %s ...' %
+                                (saved_var_name, checkpoint_file_path))
+    return restore_map
@@ -4,19 +4,25 @@ import unittest

 from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.pipelines import pipeline
+from modelscope.pipelines.nlp import TranslationPipeline
 from modelscope.utils.constant import Tasks
 from modelscope.utils.test_utils import test_level


 class TranslationTest(unittest.TestCase):
-    model_id = 'damo/nlp_csanmt_translation'
-    inputs = 'Gut@@ ach : Incre@@ ased safety for pedestri@@ ans'
+    model_id = 'damo/nlp_csanmt_translation_zh2en'
+    inputs = '声明 补充 说 , 沃伦 的 同事 都 深感 震惊 , 并且 希望 他 能够 投@@ 案@@ 自@@ 首 。'

     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_with_model_name(self):
         pipeline_ins = pipeline(task=Tasks.translation, model=self.model_id)
         print(pipeline_ins(input=self.inputs))

+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_with_default_model(self):
+        pipeline_ins = pipeline(task=Tasks.translation)
+        print(pipeline_ins(input=self.inputs))
+

 if __name__ == '__main__':
     unittest.main()
@@ -0,0 +1,18 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+from modelscope.trainers.nlp import CsanmtTranslationTrainer
+from modelscope.utils.test_utils import test_level
+
+
+class TranslationTest(unittest.TestCase):
+    model_id = 'damo/nlp_csanmt_translation_zh2en'
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_name(self):
+        trainer = CsanmtTranslationTrainer(model=self.model_id)
+        trainer.train()
+
+
+if __name__ == '__main__':
+    unittest.main()