Browse Source

modify 0.7

tags/v0.7.0-beta
李嘉琪 5 years ago
parent
commit
0e9815f63c
1 changed files with 12 additions and 13 deletions
  1. +12
    -13
      mindspore/train/model.py

+ 12
- 13
mindspore/train/model.py View File

@@ -37,6 +37,16 @@ from .dataset_helper import DatasetHelper
from . import amp from . import amp




def _transfer_tensor_to_tuple(inputs):
"""
If the input is a tensor, convert it to a tuple. If not, the output is unchanged.
"""
if isinstance(inputs, Tensor):
return (inputs,)

return inputs


class Model: class Model:
""" """
High-Level API for Training or Testing. High-Level API for Training or Testing.
@@ -386,15 +396,6 @@ class Model:


return [callbacks] return [callbacks]


def _transfer_tensor_to_tuple(self, inputs):
"""
If the input is a tensor, convert it to a tuple. If not, the output is unchanged.
"""
if isinstance(inputs, Tensor):
return (inputs,)

return inputs

def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None, sink_size=-1): def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None, sink_size=-1):
""" """
Training process. The data would be passed to network through dataset channel. Training process. The data would be passed to network through dataset channel.
@@ -437,7 +438,6 @@ class Model:


# for data sink dataset_helper only iter once, other wise iter epoch_size times. # for data sink dataset_helper only iter once, other wise iter epoch_size times.
for inputs in dataset_helper: for inputs in dataset_helper:
inputs = self._transfer_tensor_to_tuple(inputs)
if _need_to_full() and context.get_context("device_target") == "GPU": if _need_to_full() and context.get_context("device_target") == "GPU":
inputs = _to_full_tensor(inputs, self._device_number, self._global_rank) inputs = _to_full_tensor(inputs, self._device_number, self._global_rank)
list_callback.step_begin(run_context) list_callback.step_begin(run_context)
@@ -485,7 +485,7 @@ class Model:
list_callback.epoch_begin(run_context) list_callback.epoch_begin(run_context)


for next_element in dataset_helper: for next_element in dataset_helper:
next_element = self._transfer_tensor_to_tuple(next_element)
next_element = _transfer_tensor_to_tuple(next_element)
len_element = len(next_element) len_element = len(next_element)
if self._loss_fn and len_element != 2: if self._loss_fn and len_element != 2:
raise ValueError("when loss_fn is not None, train_dataset should" raise ValueError("when loss_fn is not None, train_dataset should"
@@ -603,7 +603,6 @@ class Model:
list_callback.begin(run_context) list_callback.begin(run_context)


for inputs in dataset_helper: for inputs in dataset_helper:
inputs = self._transfer_tensor_to_tuple(inputs)
cb_params.cur_step_num += 1 cb_params.cur_step_num += 1
list_callback.step_begin(run_context) list_callback.step_begin(run_context)


@@ -642,7 +641,7 @@ class Model:
for next_element in dataset_helper: for next_element in dataset_helper:
cb_params.cur_step_num += 1 cb_params.cur_step_num += 1
list_callback.step_begin(run_context) list_callback.step_begin(run_context)
next_element = self._transfer_tensor_to_tuple(next_element)
next_element = _transfer_tensor_to_tuple(next_element)
outputs = self._eval_network(*next_element) outputs = self._eval_network(*next_element)
cb_params.net_outputs = outputs cb_params.net_outputs = outputs
list_callback.step_end(run_context) list_callback.step_end(run_context)


Loading…
Cancel
Save