From 2f5a454ef490da4aa4d2f3c9869f94ca31a38f12 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=98=89=E7=90=AA?= Date: Wed, 26 Aug 2020 15:57:17 +0800 Subject: [PATCH] transfer_tensor_to_tuple --- mindspore/train/model.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/mindspore/train/model.py b/mindspore/train/model.py index 324b09b453..ea2fbc6672 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -386,6 +386,15 @@ class Model: return [callbacks] + def _transfer_tensor_to_tuple(self, inputs): + """ + If the input is a tensor, convert it to a tuple. If not, the output is unchanged. + """ + if isinstance(inputs, Tensor): + return (inputs,) + + return inputs + def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None, sink_size=-1): """ Training process. The data would be passed to network through dataset channel. @@ -428,6 +437,7 @@ class Model: # for data sink dataset_helper only iter once, other wise iter epoch_size times. for inputs in dataset_helper: + inputs = self._transfer_tensor_to_tuple(inputs) if _need_to_full() and context.get_context("device_target") == "GPU": inputs = _to_full_tensor(inputs, self._device_number, self._global_rank) list_callback.step_begin(run_context) @@ -475,6 +485,7 @@ class Model: list_callback.epoch_begin(run_context) for next_element in dataset_helper: + next_element = self._transfer_tensor_to_tuple(next_element) len_element = len(next_element) if self._loss_fn and len_element != 2: raise ValueError("when loss_fn is not None, train_dataset should" @@ -592,6 +603,7 @@ class Model: list_callback.begin(run_context) for inputs in dataset_helper: + inputs = self._transfer_tensor_to_tuple(inputs) cb_params.cur_step_num += 1 list_callback.step_begin(run_context) @@ -630,6 +642,7 @@ class Model: for next_element in dataset_helper: cb_params.cur_step_num += 1 list_callback.step_begin(run_context) + next_element = self._transfer_tensor_to_tuple(next_element) outputs = self._eval_network(*next_element) cb_params.net_outputs = outputs list_callback.step_end(run_context)