|
|
|
@@ -217,6 +217,94 @@ class Model: |
|
|
|
scaling_sens /= self._device_number |
|
|
|
return scaling_sens |
|
|
|
|
|
|
|
def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode): |
|
|
|
"""Initializes dataset.""" |
|
|
|
need_wrap = False |
|
|
|
if dataset_sink_mode: |
|
|
|
# remove later to deal with loop sink |
|
|
|
if not hasattr(dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \ |
|
|
|
and not context.get_context("enable_ge"): |
|
|
|
need_wrap = True |
|
|
|
|
|
|
|
if not is_train: |
|
|
|
dataset.__loop_size__ = 1 |
|
|
|
|
|
|
|
dataset_helper = DatasetHelper(dataset, dataset_sink_mode) |
|
|
|
|
|
|
|
# remove later to deal with loop sink |
|
|
|
if need_wrap: |
|
|
|
network = nn.DataWrapper(network, *(dataset_helper.types_shapes()), dataset.__ME_INITED__) |
|
|
|
network.set_train(is_train) |
|
|
|
network.phase = phase |
|
|
|
|
|
|
|
return dataset_helper, network |
|
|
|
|
|
|
|
def init(self, train_dataset=None, valid_dataset=None): |
|
|
|
""" |
|
|
|
Initializes compute graphs and data graphs with sink mode. |
|
|
|
|
|
|
|
Note: |
|
|
|
Pre-init process only supports `GRAPH_MODE` and `Ascend` target currently. |
|
|
|
|
|
|
|
Args: |
|
|
|
train_dataset (Dataset): A training dataset iterator. If define `train_dataset`, training graphs will be |
|
|
|
initialized. Default: None. |
|
|
|
valid_dataset (Dataset): A evaluating dataset iterator. If define `valid_dataset`, evaluation graphs will |
|
|
|
be initialized, and `metrics` in `Model` can not be None. Default: None. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> train_dataset = get_train_dataset() |
|
|
|
>>> valid_dataset = get_valid_dataset() |
|
|
|
>>> net = Net() |
|
|
|
>>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) |
|
|
|
>>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) |
|
|
|
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={'acc'}) |
|
|
|
>>> model.init(train_dataset, valid_dataset) |
|
|
|
>>> model.train(2, train_dataset) |
|
|
|
>>> model.eval(valid_dataset) |
|
|
|
""" |
|
|
|
if context.get_context("mode") != context.GRAPH_MODE or context.get_context("device_target") != "Ascend": |
|
|
|
raise RuntimeError('Pre-init process only supports GRAPH MODE and Ascend target currently.') |
|
|
|
|
|
|
|
if not train_dataset and not valid_dataset: |
|
|
|
raise ValueError('Both train_dataset and valid_dataset can not be None or empty.') |
|
|
|
|
|
|
|
_device_number_check(self._parallel_mode, self._device_number) |
|
|
|
|
|
|
|
if train_dataset: |
|
|
|
_parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast) |
|
|
|
self._train_network.set_train() |
|
|
|
self._train_network.phase = 'train' |
|
|
|
|
|
|
|
if self._parameter_broadcast: |
|
|
|
self._train_network.set_broadcast_flag() |
|
|
|
|
|
|
|
train_dataset_helper, train_network = self._exec_preprocess(self._train_network, |
|
|
|
is_train=True, |
|
|
|
phase='train', |
|
|
|
dataset=train_dataset, |
|
|
|
dataset_sink_mode=True) |
|
|
|
self._train_network = train_network |
|
|
|
for inputs in train_dataset_helper: |
|
|
|
self._train_network.compile(*inputs) |
|
|
|
break |
|
|
|
|
|
|
|
if valid_dataset: |
|
|
|
if not self._metric_fns: |
|
|
|
raise RuntimeError('If define `valid_dataset`, metric fn can not be None or empty.') |
|
|
|
|
|
|
|
self._eval_network.set_train(False) |
|
|
|
self._eval_network.phase = 'eval' |
|
|
|
valid_dataset_helper, eval_network = self._exec_preprocess(self._eval_network, |
|
|
|
is_train=False, |
|
|
|
phase='eval', |
|
|
|
dataset=valid_dataset, |
|
|
|
dataset_sink_mode=True) |
|
|
|
self._eval_network = eval_network |
|
|
|
for inputs in valid_dataset_helper: |
|
|
|
self._eval_network.compile(*inputs) |
|
|
|
break |
|
|
|
|
|
|
|
def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True): |
|
|
|
""" |
|
|
|
Training. |
|
|
|
@@ -277,21 +365,15 @@ class Model: |
|
|
|
list_callback (_ListCallback): Executor of callback list. Default: None. |
|
|
|
cb_params (_InternalCallbackParam): Callback parameters. Default: None. |
|
|
|
""" |
|
|
|
# remove later to deal with loop sink |
|
|
|
need_wrap = False |
|
|
|
if not hasattr(train_dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \ |
|
|
|
and not context.get_context("enable_ge"): |
|
|
|
need_wrap = True |
|
|
|
|
|
|
|
dataset_helper = DatasetHelper(train_dataset) |
|
|
|
# remove later to deal with loop sink |
|
|
|
if need_wrap: |
|
|
|
self._train_network = nn.DataWrapper(self._train_network, *(dataset_helper.types_shapes()), |
|
|
|
train_dataset.__ME_INITED__) |
|
|
|
cb_params.train_network = self._train_network |
|
|
|
self._train_network.set_train() |
|
|
|
|
|
|
|
dataset_helper, train_network = self._exec_preprocess(self._train_network, |
|
|
|
is_train=True, |
|
|
|
phase='train', |
|
|
|
dataset=train_dataset, |
|
|
|
dataset_sink_mode=True) |
|
|
|
self._train_network = train_network |
|
|
|
cb_params.train_network = self._train_network |
|
|
|
cb_params.cur_step_num = 0 |
|
|
|
|
|
|
|
loop_size = dataset_helper.loop_size() |
|
|
|
run_context = RunContext(cb_params) |
|
|
|
list_callback.begin(run_context) |
|
|
|
@@ -331,7 +413,11 @@ class Model: |
|
|
|
list_callback (_ListCallback): Executor of callback list. Default: None. |
|
|
|
cb_params (_InternalCallbackParam): Callback parameters. Default: None. |
|
|
|
""" |
|
|
|
dataset_helper = DatasetHelper(train_dataset, dataset_sink_mode=False) |
|
|
|
dataset_helper, _ = self._exec_preprocess(self._train_network, |
|
|
|
is_train=True, |
|
|
|
phase='train', |
|
|
|
dataset=train_dataset, |
|
|
|
dataset_sink_mode=False) |
|
|
|
cb_params.cur_step_num = 0 |
|
|
|
run_context = RunContext(cb_params) |
|
|
|
list_callback.begin(run_context) |
|
|
|
@@ -437,26 +523,15 @@ class Model: |
|
|
|
Returns: |
|
|
|
Dict, returns the loss value & metrics values for the model in test mode. |
|
|
|
""" |
|
|
|
_device_number_check(self._parallel_mode, self._device_number) |
|
|
|
|
|
|
|
run_context = RunContext(cb_params) |
|
|
|
|
|
|
|
# remove later to deal with loop sink |
|
|
|
need_wrap = False |
|
|
|
if not hasattr(valid_dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \ |
|
|
|
and not context.get_context("enable_ge"): |
|
|
|
need_wrap = True |
|
|
|
|
|
|
|
valid_dataset.__loop_size__ = 1 |
|
|
|
dataset_helper = DatasetHelper(valid_dataset) |
|
|
|
|
|
|
|
# remove later to deal with loop sink |
|
|
|
if need_wrap: |
|
|
|
self._eval_network = nn.DataWrapper(self._eval_network, *(dataset_helper.types_shapes()), |
|
|
|
valid_dataset.__ME_INITED__) |
|
|
|
self._eval_network.set_train(mode=False) |
|
|
|
self._eval_network.phase = 'eval' |
|
|
|
|
|
|
|
dataset_helper, eval_network = self._exec_preprocess(self._eval_network, |
|
|
|
is_train=False, |
|
|
|
phase='eval', |
|
|
|
dataset=valid_dataset, |
|
|
|
dataset_sink_mode=True) |
|
|
|
self._eval_network = eval_network |
|
|
|
cb_params.eval_network = self._eval_network |
|
|
|
list_callback.begin(run_context) |
|
|
|
|
|
|
|
for inputs in dataset_helper: |
|
|
|
@@ -490,7 +565,11 @@ class Model: |
|
|
|
run_context = RunContext(cb_params) |
|
|
|
list_callback.begin(run_context) |
|
|
|
|
|
|
|
dataset_helper = DatasetHelper(valid_dataset, dataset_sink_mode=False) |
|
|
|
dataset_helper, _ = self._exec_preprocess(self._eval_network, |
|
|
|
is_train=False, |
|
|
|
phase='eval', |
|
|
|
dataset=valid_dataset, |
|
|
|
dataset_sink_mode=False) |
|
|
|
for next_element in dataset_helper: |
|
|
|
cb_params.cur_step_num += 1 |
|
|
|
list_callback.step_begin(run_context) |
|
|
|
@@ -532,6 +611,7 @@ class Model: |
|
|
|
>>> model.eval(dataset) |
|
|
|
""" |
|
|
|
check_bool(dataset_sink_mode) |
|
|
|
_device_number_check(self._parallel_mode, self._device_number) |
|
|
|
if not self._metric_fns: |
|
|
|
raise ValueError("metric fn can not be None or empty.") |
|
|
|
|
|
|
|
|