| @@ -37,6 +37,16 @@ from .dataset_helper import DatasetHelper | |||||
| from . import amp | from . import amp | ||||
| def _transfer_tensor_to_tuple(inputs): | |||||
| """ | |||||
| If the input is a tensor, convert it to a tuple. If not, the output is unchanged. | |||||
| """ | |||||
| if isinstance(inputs, Tensor): | |||||
| return (inputs,) | |||||
| return inputs | |||||
| class Model: | class Model: | ||||
| """ | """ | ||||
| High-Level API for Training or Testing. | High-Level API for Training or Testing. | ||||
| @@ -386,15 +396,6 @@ class Model: | |||||
| return [callbacks] | return [callbacks] | ||||
| def _transfer_tensor_to_tuple(self, inputs): | |||||
| """ | |||||
| If the input is a tensor, convert it to a tuple. If not, the output is unchanged. | |||||
| """ | |||||
| if isinstance(inputs, Tensor): | |||||
| return (inputs,) | |||||
| return inputs | |||||
| def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None, sink_size=-1): | def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None, sink_size=-1): | ||||
| """ | """ | ||||
| Training process. The data would be passed to network through dataset channel. | Training process. The data would be passed to network through dataset channel. | ||||
| @@ -437,7 +438,6 @@ class Model: | |||||
| # for data sink dataset_helper only iter once, other wise iter epoch_size times. | # for data sink dataset_helper only iter once, other wise iter epoch_size times. | ||||
| for inputs in dataset_helper: | for inputs in dataset_helper: | ||||
| inputs = self._transfer_tensor_to_tuple(inputs) | |||||
| if _need_to_full() and context.get_context("device_target") == "GPU": | if _need_to_full() and context.get_context("device_target") == "GPU": | ||||
| inputs = _to_full_tensor(inputs, self._device_number, self._global_rank) | inputs = _to_full_tensor(inputs, self._device_number, self._global_rank) | ||||
| list_callback.step_begin(run_context) | list_callback.step_begin(run_context) | ||||
| @@ -485,7 +485,7 @@ class Model: | |||||
| list_callback.epoch_begin(run_context) | list_callback.epoch_begin(run_context) | ||||
| for next_element in dataset_helper: | for next_element in dataset_helper: | ||||
| next_element = self._transfer_tensor_to_tuple(next_element) | |||||
| next_element = _transfer_tensor_to_tuple(next_element) | |||||
| len_element = len(next_element) | len_element = len(next_element) | ||||
| if self._loss_fn and len_element != 2: | if self._loss_fn and len_element != 2: | ||||
| raise ValueError("when loss_fn is not None, train_dataset should" | raise ValueError("when loss_fn is not None, train_dataset should" | ||||
| @@ -603,7 +603,6 @@ class Model: | |||||
| list_callback.begin(run_context) | list_callback.begin(run_context) | ||||
| for inputs in dataset_helper: | for inputs in dataset_helper: | ||||
| inputs = self._transfer_tensor_to_tuple(inputs) | |||||
| cb_params.cur_step_num += 1 | cb_params.cur_step_num += 1 | ||||
| list_callback.step_begin(run_context) | list_callback.step_begin(run_context) | ||||
| @@ -642,7 +641,7 @@ class Model: | |||||
| for next_element in dataset_helper: | for next_element in dataset_helper: | ||||
| cb_params.cur_step_num += 1 | cb_params.cur_step_num += 1 | ||||
| list_callback.step_begin(run_context) | list_callback.step_begin(run_context) | ||||
| next_element = self._transfer_tensor_to_tuple(next_element) | |||||
| next_element = _transfer_tensor_to_tuple(next_element) | |||||
| outputs = self._eval_network(*next_element) | outputs = self._eval_network(*next_element) | ||||
| cb_params.net_outputs = outputs | cb_params.net_outputs = outputs | ||||
| list_callback.step_end(run_context) | list_callback.step_end(run_context) | ||||