|
|
|
@@ -37,6 +37,16 @@ from .dataset_helper import DatasetHelper |
|
|
|
from . import amp |
|
|
|
|
|
|
|
|
|
|
|
def _transfer_tensor_to_tuple(inputs): |
|
|
|
""" |
|
|
|
If the input is a tensor, convert it to a tuple. If not, the output is unchanged. |
|
|
|
""" |
|
|
|
if isinstance(inputs, Tensor): |
|
|
|
return (inputs,) |
|
|
|
|
|
|
|
return inputs |
|
|
|
|
|
|
|
|
|
|
|
class Model: |
|
|
|
""" |
|
|
|
High-Level API for Training or Testing. |
|
|
|
@@ -386,15 +396,6 @@ class Model: |
|
|
|
|
|
|
|
return [callbacks] |
|
|
|
|
|
|
|
def _transfer_tensor_to_tuple(self, inputs): |
|
|
|
""" |
|
|
|
If the input is a tensor, convert it to a tuple. If not, the output is unchanged. |
|
|
|
""" |
|
|
|
if isinstance(inputs, Tensor): |
|
|
|
return (inputs,) |
|
|
|
|
|
|
|
return inputs |
|
|
|
|
|
|
|
def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None, sink_size=-1): |
|
|
|
""" |
|
|
|
Training process. The data would be passed to network through dataset channel. |
|
|
|
@@ -437,7 +438,6 @@ class Model: |
|
|
|
|
|
|
|
# for data sink dataset_helper only iter once, other wise iter epoch_size times. |
|
|
|
for inputs in dataset_helper: |
|
|
|
inputs = self._transfer_tensor_to_tuple(inputs) |
|
|
|
if _need_to_full() and context.get_context("device_target") == "GPU": |
|
|
|
inputs = _to_full_tensor(inputs, self._device_number, self._global_rank) |
|
|
|
list_callback.step_begin(run_context) |
|
|
|
@@ -485,7 +485,7 @@ class Model: |
|
|
|
list_callback.epoch_begin(run_context) |
|
|
|
|
|
|
|
for next_element in dataset_helper: |
|
|
|
next_element = self._transfer_tensor_to_tuple(next_element) |
|
|
|
next_element = _transfer_tensor_to_tuple(next_element) |
|
|
|
len_element = len(next_element) |
|
|
|
if self._loss_fn and len_element != 2: |
|
|
|
raise ValueError("when loss_fn is not None, train_dataset should" |
|
|
|
@@ -603,7 +603,6 @@ class Model: |
|
|
|
list_callback.begin(run_context) |
|
|
|
|
|
|
|
for inputs in dataset_helper: |
|
|
|
inputs = self._transfer_tensor_to_tuple(inputs) |
|
|
|
cb_params.cur_step_num += 1 |
|
|
|
list_callback.step_begin(run_context) |
|
|
|
|
|
|
|
@@ -642,7 +641,7 @@ class Model: |
|
|
|
for next_element in dataset_helper: |
|
|
|
cb_params.cur_step_num += 1 |
|
|
|
list_callback.step_begin(run_context) |
|
|
|
next_element = self._transfer_tensor_to_tuple(next_element) |
|
|
|
next_element = _transfer_tensor_to_tuple(next_element) |
|
|
|
outputs = self._eval_network(*next_element) |
|
|
|
cb_params.net_outputs = outputs |
|
|
|
list_callback.step_end(run_context) |
|
|
|
|