Browse Source

[to #42322933] Fix memory error during translation fine-tuning

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10713843

    * [to #42322933] Fix memory error during translation fine-tuning
master
xiangpeng.wxp yingda.chen 3 years ago
parent
commit
d6ea41fb70
2 changed files with 50 additions and 35 deletions
  1. +41
    -32
      modelscope/trainers/nlp/csanmt_translation_trainer.py
  2. +9
    -3
      tests/trainers/test_translation_trainer.py

+ 41
- 32
modelscope/trainers/nlp/csanmt_translation_trainer.py View File

@@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.


import os.path as osp import os.path as osp
import time
from typing import Dict, Optional from typing import Dict, Optional


import tensorflow as tf import tensorflow as tf
@@ -122,8 +123,7 @@ class CsanmtTranslationTrainer(BaseTrainer):
self.params['scale_l1'] = self.cfg['train']['scale_l1'] self.params['scale_l1'] = self.cfg['train']['scale_l1']
self.params['scale_l2'] = self.cfg['train']['scale_l2'] self.params['scale_l2'] = self.cfg['train']['scale_l2']
self.params['train_max_len'] = self.cfg['train']['train_max_len'] self.params['train_max_len'] = self.cfg['train']['train_max_len']
self.params['max_training_steps'] = self.cfg['train'][
'max_training_steps']
self.params['num_of_epochs'] = self.cfg['train']['num_of_epochs']
self.params['save_checkpoints_steps'] = self.cfg['train'][ self.params['save_checkpoints_steps'] = self.cfg['train'][
'save_checkpoints_steps'] 'save_checkpoints_steps']
self.params['num_of_samples'] = self.cfg['train']['num_of_samples'] self.params['num_of_samples'] = self.cfg['train']['num_of_samples']
@@ -144,14 +144,15 @@ class CsanmtTranslationTrainer(BaseTrainer):
vocab_src = osp.join(self.model_dir, self.params['vocab_src']) vocab_src = osp.join(self.model_dir, self.params['vocab_src'])
vocab_trg = osp.join(self.model_dir, self.params['vocab_trg']) vocab_trg = osp.join(self.model_dir, self.params['vocab_trg'])


epoch = 0
iteration = 0 iteration = 0


with self._session.as_default() as tf_session: with self._session.as_default() as tf_session:
while True: while True:
iteration += 1
if iteration >= self.params['max_training_steps']:
epoch += 1
if epoch >= self.params['num_of_epochs']:
break break
tf.logging.info('%s: Epoch %i' % (__name__, epoch))
train_input_fn = input_fn( train_input_fn = input_fn(
train_src, train_src,
train_trg, train_trg,
@@ -160,36 +161,44 @@ class CsanmtTranslationTrainer(BaseTrainer):
batch_size_words=self.params['train_batch_size_words'], batch_size_words=self.params['train_batch_size_words'],
max_len=self.params['train_max_len'], max_len=self.params['train_max_len'],
num_gpus=self.params['num_gpus'] num_gpus=self.params['num_gpus']
if self.params['num_gpus'] > 0 else 1,
if self.params['num_gpus'] > 1 else 1,
is_train=True, is_train=True,
session=tf_session, session=tf_session,
iteration=iteration)
epoch=epoch)


features, labels = train_input_fn features, labels = train_input_fn


features_batch, labels_batch = tf_session.run(
[features, labels])

feed_dict = {
self.source_wids: features_batch,
self.target_wids: labels_batch
}
sess_outputs = self._session.run(
self.output, feed_dict=feed_dict)
loss_step = sess_outputs['loss']
logger.info('Iteration: {}, step loss: {:.6f}'.format(
iteration, loss_step))

if iteration % self.params['save_checkpoints_steps'] == 0:
tf.logging.info('%s: Saving model on step: %d.' %
(__name__, iteration))
ck_path = self.model_dir + 'model.ckpt'
self.model_saver.save(
tf_session,
ck_path,
global_step=tf.train.get_global_step())

tf.logging.info('%s: NMT training completed at time: %s.')
try:
while True:
features_batch, labels_batch = tf_session.run(
[features, labels])
iteration += 1
feed_dict = {
self.source_wids: features_batch,
self.target_wids: labels_batch
}
sess_outputs = self._session.run(
self.output, feed_dict=feed_dict)
loss_step = sess_outputs['loss']
logger.info('Iteration: {}, step loss: {:.6f}'.format(
iteration, loss_step))

if iteration % self.params[
'save_checkpoints_steps'] == 0:
tf.logging.info('%s: Saving model on step: %d.' %
(__name__, iteration))
ck_path = self.model_dir + 'model.ckpt'
self.model_saver.save(
tf_session,
ck_path,
global_step=tf.train.get_global_step())

except tf.errors.OutOfRangeError:
tf.logging.info('epoch %d end!' % (epoch))

tf.logging.info(
'%s: NMT training completed at time: %s.' %
(__name__, time.asctime(time.localtime(time.time()))))


def evaluate(self, def evaluate(self,
checkpoint_path: Optional[str] = None, checkpoint_path: Optional[str] = None,
@@ -222,7 +231,7 @@ def input_fn(src_file,
num_gpus=1, num_gpus=1,
is_train=True, is_train=True,
session=None, session=None,
iteration=None):
epoch=None):
src_vocab = tf.lookup.StaticVocabularyTable( src_vocab = tf.lookup.StaticVocabularyTable(
tf.lookup.TextFileInitializer( tf.lookup.TextFileInitializer(
src_vocab_file, src_vocab_file,
@@ -291,7 +300,7 @@ def input_fn(src_file,


if is_train: if is_train:
session.run(iterator.initializer) session.run(iterator.initializer)
if iteration == 1:
if epoch == 1:
session.run(tf.tables_initializer()) session.run(tf.tables_initializer())
return features, labels return features, labels




+ 9
- 3
tests/trainers/test_translation_trainer.py View File

@@ -6,11 +6,17 @@ from modelscope.utils.test_utils import test_level




class TranslationTest(unittest.TestCase): class TranslationTest(unittest.TestCase):
model_id = 'damo/nlp_csanmt_translation_zh2en'


@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_name(self):
trainer = CsanmtTranslationTrainer(model=self.model_id)
def test_run_with_model_name_for_en2zh(self):
model_id = 'damo/nlp_csanmt_translation_en2zh'
trainer = CsanmtTranslationTrainer(model=model_id)
trainer.train()

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_name_for_en2fr(self):
model_id = 'damo/nlp_csanmt_translation_en2fr'
trainer = CsanmtTranslationTrainer(model=model_id)
trainer.train() trainer.train()






Loading…
Cancel
Save