|
|
@@ -267,7 +267,7 @@ class Model: |
|
|
|
|
|
|
|
|
return dataset_helper, network |
|
|
return dataset_helper, network |
|
|
|
|
|
|
|
|
def init(self, train_dataset=None, valid_dataset=None): |
|
|
|
|
|
|
|
|
def _init(self, train_dataset=None, valid_dataset=None, sink_size=-1): |
|
|
""" |
|
|
""" |
|
|
Initialize compute graphs and data graphs with the sink mode. |
|
|
Initialize compute graphs and data graphs with the sink mode. |
|
|
|
|
|
|
|
|
@@ -279,17 +279,7 @@ class Model: |
|
|
initialized. Default: None. |
|
|
initialized. Default: None. |
|
|
valid_dataset (Dataset): A evaluating dataset iterator. If `valid_dataset` is defined, evaluation graphs |
|
|
valid_dataset (Dataset): A evaluating dataset iterator. If `valid_dataset` is defined, evaluation graphs |
|
|
will be initialized, and `metrics` in `Model` can not be None. Default: None. |
|
|
will be initialized, and `metrics` in `Model` can not be None. Default: None. |
|
|
|
|
|
|
|
|
Examples: |
|
|
|
|
|
>>> train_dataset = get_train_dataset() |
|
|
|
|
|
>>> valid_dataset = get_valid_dataset() |
|
|
|
|
|
>>> net = Net() |
|
|
|
|
|
>>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) |
|
|
|
|
|
>>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) |
|
|
|
|
|
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={'acc'}) |
|
|
|
|
|
>>> model.init(train_dataset, valid_dataset) |
|
|
|
|
|
>>> model.train(2, train_dataset) |
|
|
|
|
|
>>> model.eval(valid_dataset) |
|
|
|
|
|
|
|
|
sink_size (int): Control the amount of data in each sink. Default: -1. |
|
|
""" |
|
|
""" |
|
|
if context.get_context("mode") != context.GRAPH_MODE or context.get_context("device_target") != "Ascend": |
|
|
if context.get_context("mode") != context.GRAPH_MODE or context.get_context("device_target") != "Ascend": |
|
|
raise RuntimeError('Pre-init process only supports GRAPH MODE and Ascend target currently.') |
|
|
raise RuntimeError('Pre-init process only supports GRAPH MODE and Ascend target currently.') |
|
|
@@ -309,7 +299,8 @@ class Model: |
|
|
is_train=True, |
|
|
is_train=True, |
|
|
phase='train', |
|
|
phase='train', |
|
|
dataset=train_dataset, |
|
|
dataset=train_dataset, |
|
|
dataset_sink_mode=True) |
|
|
|
|
|
|
|
|
dataset_sink_mode=True, |
|
|
|
|
|
sink_size=sink_size) |
|
|
self._train_network = train_network |
|
|
self._train_network = train_network |
|
|
for inputs in train_dataset_helper: |
|
|
for inputs in train_dataset_helper: |
|
|
self._train_network.compile(*inputs) |
|
|
self._train_network.compile(*inputs) |
|
|
|