|
|
|
@@ -108,6 +108,7 @@ class Model: |
|
|
|
|
|
|
|
self._train_network = self._build_train_network() |
|
|
|
self._build_eval_network(metrics, eval_network, eval_indexes) |
|
|
|
self._build_predict_network() |
|
|
|
|
|
|
|
def _check_kwargs(self, kwargs): |
|
|
|
for arg in kwargs: |
|
|
|
@@ -153,6 +154,12 @@ class Model: |
|
|
|
self._eval_network = nn.WithEvalCell(self._network, self._loss_fn) |
|
|
|
self._eval_indexes = [0, 1, 2] |
|
|
|
|
|
|
|
def _build_predict_network(self): |
|
|
|
"""Build the network for prediction.""" |
|
|
|
self._predict_network = self._network |
|
|
|
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): |
|
|
|
self._predict_network = _VirtualDatasetCell(self._network) |
|
|
|
|
|
|
|
def _clear_metrics(self): |
|
|
|
"""Clear metrics local values.""" |
|
|
|
for metric in self._metric_fns.values(): |
|
|
|
@@ -470,6 +477,7 @@ class Model: |
|
|
|
|
|
|
|
dataset_helper = DatasetHelper(valid_dataset, dataset_sink_mode=False) |
|
|
|
for next_element in dataset_helper: |
|
|
|
cb_params.cur_step_num += 1 |
|
|
|
list_callback.step_begin(run_context) |
|
|
|
outputs = self._eval_network(*next_element) |
|
|
|
cb_params.net_outputs = outputs |
|
|
|
@@ -549,12 +557,9 @@ class Model: |
|
|
|
>>> model = Model(Net()) |
|
|
|
>>> model.predict(input_data) |
|
|
|
""" |
|
|
|
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): |
|
|
|
self._network = _VirtualDatasetCell(self._network) |
|
|
|
|
|
|
|
self._network.set_train(False) |
|
|
|
self._predict_network.set_train(False) |
|
|
|
check_input_data(*predict_data, data_class=Tensor) |
|
|
|
result = self._network(*predict_data) |
|
|
|
result = self._predict_network(*predict_data) |
|
|
|
|
|
|
|
check_output_data(result) |
|
|
|
return result |
|
|
|
|