From 0e9815f63cebc81f1b016646bbd144bb17455afd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=98=89=E7=90=AA?= Date: Thu, 27 Aug 2020 16:10:32 +0800 Subject: [PATCH] modify 0.7 --- mindspore/train/model.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/mindspore/train/model.py b/mindspore/train/model.py index ea2fbc6672..50a5d8eacf 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -37,6 +37,16 @@ from .dataset_helper import DatasetHelper from . import amp +def _transfer_tensor_to_tuple(inputs): + """ + If the input is a tensor, convert it to a tuple. If not, the output is unchanged. + """ + if isinstance(inputs, Tensor): + return (inputs,) + + return inputs + + class Model: """ High-Level API for Training or Testing. @@ -386,15 +396,6 @@ class Model: return [callbacks] - def _transfer_tensor_to_tuple(self, inputs): - """ - If the input is a tensor, convert it to a tuple. If not, the output is unchanged. - """ - if isinstance(inputs, Tensor): - return (inputs,) - - return inputs - def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None, sink_size=-1): """ Training process. The data would be passed to network through dataset channel. @@ -437,7 +438,6 @@ class Model: # for data sink dataset_helper only iter once, other wise iter epoch_size times. for inputs in dataset_helper: - inputs = self._transfer_tensor_to_tuple(inputs) if _need_to_full() and context.get_context("device_target") == "GPU": inputs = _to_full_tensor(inputs, self._device_number, self._global_rank) list_callback.step_begin(run_context) @@ -485,7 +485,7 @@ class Model: list_callback.epoch_begin(run_context) for next_element in dataset_helper: - next_element = self._transfer_tensor_to_tuple(next_element) + next_element = _transfer_tensor_to_tuple(next_element) len_element = len(next_element) if self._loss_fn and len_element != 2: raise ValueError("when loss_fn is not None, train_dataset should" @@ -603,7 +603,6 @@ class Model: list_callback.begin(run_context) for inputs in dataset_helper: - inputs = self._transfer_tensor_to_tuple(inputs) cb_params.cur_step_num += 1 list_callback.step_begin(run_context) @@ -642,7 +641,7 @@ class Model: for next_element in dataset_helper: cb_params.cur_step_num += 1 list_callback.step_begin(run_context) - next_element = self._transfer_tensor_to_tuple(next_element) + next_element = _transfer_tensor_to_tuple(next_element) outputs = self._eval_network(*next_element) cb_params.net_outputs = outputs list_callback.step_end(run_context)