| @@ -131,7 +131,7 @@ class OFATrainer(EpochBasedTrainer): | |||||
| model.train() | model.train() | ||||
| # model_outputs = model.forward(inputs) | # model_outputs = model.forward(inputs) | ||||
| loss, sample_size, logging_output = self.criterion(model, inputs) | loss, sample_size, logging_output = self.criterion(model, inputs) | ||||
| train_outputs = {'loss': loss / 100} | |||||
| train_outputs = {'loss': loss} | |||||
| # add model output info to log | # add model output info to log | ||||
| if 'log_vars' not in train_outputs: | if 'log_vars' not in train_outputs: | ||||
| default_keys_pattern = ['loss'] | default_keys_pattern = ['loss'] | ||||
| @@ -144,7 +144,7 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss): | |||||
| sample_size = ( | sample_size = ( | ||||
| sample['target'].size(0) if self.sentence_avg else ntokens) | sample['target'].size(0) if self.sentence_avg else ntokens) | ||||
| logging_output = { | logging_output = { | ||||
| 'loss': loss.data / 100, | |||||
| 'loss': loss.data, | |||||
| 'nll_loss': nll_loss.data, | 'nll_loss': nll_loss.data, | ||||
| 'ntokens': sample['ntokens'], | 'ntokens': sample['ntokens'], | ||||
| 'nsentences': sample['nsentences'], | 'nsentences': sample['nsentences'], | ||||
| @@ -1,7 +1,5 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import glob | |||||
| import os | import os | ||||
| import os.path as osp | |||||
| import shutil | import shutil | ||||
| import unittest | import unittest | ||||
| @@ -98,8 +96,9 @@ class TestOfaTrainer(unittest.TestCase): | |||||
| trainer = build_trainer(name=Trainers.ofa, default_args=args) | trainer = build_trainer(name=Trainers.ofa, default_args=args) | ||||
| trainer.train() | trainer.train() | ||||
| self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, | |||||
| os.listdir(os.path.join(WORKSPACE, 'output'))) | |||||
| self.assertIn( | |||||
| ModelFile.TORCH_MODEL_BIN_FILE, | |||||
| os.listdir(os.path.join(WORKSPACE, ModelFile.TRAIN_OUTPUT_DIR))) | |||||
| shutil.rmtree(WORKSPACE) | shutil.rmtree(WORKSPACE) | ||||