|
|
|
@@ -20,6 +20,7 @@ from modelscope.preprocessors.ofa.utils.collate import collate_fn |
|
|
|
from modelscope.trainers import EpochBasedTrainer |
|
|
|
from modelscope.trainers.builder import TRAINERS |
|
|
|
from modelscope.trainers.optimizer.builder import build_optimizer |
|
|
|
from modelscope.trainers.parallel.utils import is_parallel |
|
|
|
from modelscope.utils.config import Config |
|
|
|
from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigKeys, |
|
|
|
ModeKeys) |
|
|
|
@@ -137,6 +138,7 @@ class OFATrainer(EpochBasedTrainer): |
|
|
|
return cfg |
|
|
|
|
|
|
|
def train_step(self, model, inputs): |
|
|
|
model = model.module if self._dist or is_parallel(model) else model |
|
|
|
model.train() |
|
|
|
loss, sample_size, logging_output = self.criterion(model, inputs) |
|
|
|
train_outputs = {'loss': loss} |
|
|
|
|