
model.py 34 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Model."""
from collections.abc import Iterable
import os
import numpy as np
from mindspore import log as logger
from ..common.tensor import Tensor
from ..nn.metrics import get_metrics
from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool, check_int
from .callback import _InternalCallbackParam, RunContext, _CallbackManager
from .. import context
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
    _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
from ..nn.metrics import Loss
from .. import nn
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
from .parallel_utils import ParallelMode
from ._utils import _to_full_tensor
from ..parallel._utils import _need_to_full
from ..common import dtype as mstype
from .dataset_helper import DatasetHelper
from . import amp


class Model:
    """
    High-Level API for Training or Testing.

    `Model` groups layers into an object with training and inference features.

    Args:
        network (Cell): The training or testing network.
        loss_fn (Cell): Objective function. If loss_fn is None, the network should contain
            the logic of loss and gradient calculation, and the parallel logic if needed. Default: None.
        optimizer (Cell): Optimizer for updating the weights. Default: None.
        metrics (Union[dict, set]): Dict or set of metrics to be evaluated by the model during
            training and testing, e.g. {'accuracy', 'recall'}. Default: None.
        eval_network (Cell): Network for evaluation. If not defined, `network` and `loss_fn` are wrapped as
            `eval_network`. Default: None.
        eval_indexes (list): If `eval_network` is defined and `eval_indexes` is None, all outputs of
            `eval_network` are passed to the metrics; otherwise `eval_indexes` must contain three
            elements, giving the positions of the loss value, the predicted value, and the label. The loss
            value is passed to the `Loss` metric; the predicted value and label are passed to the other
            metrics. Default: None.
        amp_level (str): Option for the `level` argument of `mindspore.amp.build_train_network`, the level
            for mixed-precision training. Supports ["O0", "O2", "O3"]. Default: "O0".

            - O0: Do not change.
            - O2: Cast the network to float16, keep batchnorm running in float32, use dynamic loss scaling.
            - O3: Cast the network to float16, with the additional property 'keep_batchnorm_fp32=False'.

            O2 is recommended on GPU, O3 is recommended on Ascend.
        loss_scale_manager (Union[None, LossScaleManager]): If None, the loss is not scaled; otherwise the
            loss is scaled by the LossScaleManager. If set, it overwrites the level setting. It is a keyword
            argument, e.g. use `loss_scale_manager=None` to set the value.
        keep_batchnorm_fp32 (bool): Keep batchnorm running in `float32`. If set, it overwrites the level
            setting. Default: True.

    Examples:
        >>> class Net(nn.Cell):
        >>>     def __init__(self):
        >>>         super(Net, self).__init__()
        >>>         self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
        >>>         self.bn = nn.BatchNorm2d(64)
        >>>         self.relu = nn.ReLU()
        >>>         self.flatten = nn.Flatten()
        >>>         self.fc = nn.Dense(64*224*224, 12)  # padding=0
        >>>
        >>>     def construct(self, x):
        >>>         x = self.conv(x)
        >>>         x = self.bn(x)
        >>>         x = self.relu(x)
        >>>         x = self.flatten(x)
        >>>         out = self.fc(x)
        >>>         return out
        >>>
        >>> net = Net()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
        >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
        >>> dataset = get_dataset()
        >>> model.train(2, dataset)
    """

    def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, eval_network=None,
                 eval_indexes=None, amp_level="O0", **kwargs):
        self._network = network
        self._loss_fn = loss_fn
        self._optimizer = optimizer
        self._loss_scale_manager = None
        self._loss_scale_manager_set = False
        self._keep_bn_fp32 = True
        self._check_kwargs(kwargs)
        self._amp_level = amp_level
        self._process_amp_args(kwargs)
        self._parallel_mode = _get_parallel_mode()
        self._device_number = _get_device_num()
        self._global_rank = _get_global_rank()
        self._parameter_broadcast = _get_parameter_broadcast()

        self._train_network = self._build_train_network()
        self._build_eval_network(metrics, eval_network, eval_indexes)
        self._build_predict_network()

    def _process_amp_args(self, kwargs):
        if self._amp_level in ["O0", "O3"]:
            self._keep_bn_fp32 = False
        if 'keep_batchnorm_fp32' in kwargs:
            self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32']
        if 'loss_scale_manager' in kwargs:
            self._loss_scale_manager = kwargs['loss_scale_manager']
            self._loss_scale_manager_set = True

    def _check_kwargs(self, kwargs):
        for arg in kwargs:
            if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32']:
                raise ValueError(f"Unsupported arg '{arg}'")

    def _build_train_network(self):
        """Build train network"""
        network = self._network
        if self._optimizer:
            if self._loss_scale_manager_set:
                network = amp.build_train_network(network,
                                                  self._optimizer,
                                                  self._loss_fn,
                                                  level=self._amp_level,
                                                  loss_scale_manager=self._loss_scale_manager,
                                                  keep_batchnorm_fp32=self._keep_bn_fp32)
            else:
                network = amp.build_train_network(network,
                                                  self._optimizer,
                                                  self._loss_fn,
                                                  level=self._amp_level,
                                                  keep_batchnorm_fp32=self._keep_bn_fp32)
        elif self._loss_fn:
            network = nn.WithLossCell(network, self._loss_fn)
        # TODO: it may be necessary to handle the case where loss_fn is not None but optimizer is None
        if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
            network.set_auto_parallel()
        return network
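    # A rough sketch of what the wrappers above amount to (an assumption about
    # amp.build_train_network's behavior, not a verbatim quote of it): with both
    # an optimizer and a loss_fn, the result is conceptually
    #
    #   >>> loss_net = nn.WithLossCell(net, loss_fn)              # forward pass + loss
    #   >>> train_net = nn.TrainOneStepCell(loss_net, optimizer)  # one fused update step
    #
    # plus the O2/O3 float16 casts and, when a loss_scale_manager is set, a
    # loss-scaling variant of the one-step cell.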

    def _build_eval_network(self, metrics, eval_network, eval_indexes):
        """Build the network for evaluation."""
        self._metric_fns = get_metrics(metrics)
        if not self._metric_fns:
            return

        if eval_network is not None:
            if eval_indexes is not None and not (isinstance(eval_indexes, list) and len(eval_indexes) == 3):
                raise ValueError("Eval_indexes must be a list or None. If eval_indexes is a list, "
                                 "its length must be three. But got {}".format(eval_indexes))
            self._eval_network = eval_network
            self._eval_indexes = eval_indexes
        else:
            if self._loss_fn is None:
                raise ValueError("loss_fn can not be None.")
            self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level == "O2")
            self._eval_indexes = [0, 1, 2]

        if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
            if self._optimizer:
                self._eval_network = _VirtualDatasetCell(self._eval_network)
            self._eval_network.set_auto_parallel()

    def _build_predict_network(self):
        """Build the network for prediction."""
        self._predict_network = self._network
        if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
            self._predict_network = _VirtualDatasetCell(self._network)
            self._predict_network.set_auto_parallel()

    def _clear_metrics(self):
        """Clear metrics local values."""
        for metric in self._metric_fns.values():
            metric.clear()

    def _update_metrics(self, outputs):
        """Update metrics local values."""
        if not isinstance(outputs, tuple):
            raise ValueError("The `outputs` is not a tuple.")

        if self._eval_indexes is not None and len(outputs) < 3:
            raise ValueError("The length of `outputs` must be greater than or equal to 3, "
                             "but got {}".format(len(outputs)))

        for metric in self._metric_fns.values():
            if self._eval_indexes is None:
                metric.update(*outputs)
            else:
                if isinstance(metric, Loss):
                    metric.update(outputs[self._eval_indexes[0]])
                else:
                    metric.update(outputs[self._eval_indexes[1]], outputs[self._eval_indexes[2]])
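    # Worked example of the routing above: with eval_indexes = [0, 1, 2] and
    # outputs = (loss, logits, label), a Loss metric receives
    # metric.update(outputs[0]), while a metric such as Accuracy receives
    # metric.update(outputs[1], outputs[2]).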

    def _get_metrics(self):
        """Get metrics local values."""
        metrics = dict()
        for key, value in self._metric_fns.items():
            metrics[key] = value.eval()
        return metrics

    def _get_scaling_sens(self):
        """Get the scaling sens."""
        scaling_sens = 1
        if self._loss_scale_manager is not None:
            scaling_sens = self._loss_scale_manager.get_loss_scale()
        if self._parallel_mode == ParallelMode.DATA_PARALLEL:
            scaling_sens /= self._device_number
        return scaling_sens
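    # Worked example (hypothetical numbers, not from the original source): with
    # a FixedLossScaleManager whose get_loss_scale() returns 1024, running
    # DATA_PARALLEL on 8 devices yields scaling_sens = 1024 / 8 = 128 per device.
    # Presumably the division compensates for gradients being summed across
    # devices, so the effective loss scale stays 1024 overall.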

    def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1):
        """Initializes dataset."""
        need_wrap = False
        if dataset_sink_mode:
            # remove later to deal with loop sink
            if not hasattr(dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \
                    and not context.get_context("enable_ge"):
                need_wrap = True

            if not is_train:
                dataset.__loop_size__ = 1

        dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size)

        # remove later to deal with loop sink
        if need_wrap:
            network = nn.DataWrapper(network, *(dataset_helper.types_shapes()), dataset.__ME_INITED__)

        network.set_train(is_train)
        network.phase = phase

        if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
            network.set_auto_parallel()

        return dataset_helper, network

    def init(self, train_dataset=None, valid_dataset=None):
        """
        Initializes compute graphs and data graphs with sink mode.

        Note:
            Pre-init process only supports `GRAPH_MODE` and `Ascend` target currently.

        Args:
            train_dataset (Dataset): A training dataset iterator. If `train_dataset` is defined, training
                graphs will be initialized. Default: None.
            valid_dataset (Dataset): An evaluation dataset iterator. If `valid_dataset` is defined, evaluation
                graphs will be initialized, and `metrics` in `Model` can not be None. Default: None.

        Examples:
            >>> train_dataset = get_train_dataset()
            >>> valid_dataset = get_valid_dataset()
            >>> net = Net()
            >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
            >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
            >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={'acc'})
            >>> model.init(train_dataset, valid_dataset)
            >>> model.train(2, train_dataset)
            >>> model.eval(valid_dataset)
        """
        if context.get_context("mode") != context.GRAPH_MODE or context.get_context("device_target") != "Ascend":
            raise RuntimeError('Pre-init process only supports GRAPH MODE and Ascend target currently.')

        if not train_dataset and not valid_dataset:
            raise ValueError('Both train_dataset and valid_dataset can not be None or empty.')

        _device_number_check(self._parallel_mode, self._device_number)

        if train_dataset:
            _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
            self._train_network.set_train()
            self._train_network.phase = 'train'

            if self._parameter_broadcast:
                self._train_network.set_broadcast_flag()

            train_dataset.__no_send__ = True
            train_dataset_helper, train_network = self._exec_preprocess(self._train_network,
                                                                        is_train=True,
                                                                        phase='train',
                                                                        dataset=train_dataset,
                                                                        dataset_sink_mode=True)
            self._train_network = train_network
            for inputs in train_dataset_helper:
                self._train_network.compile(*inputs)
                break

        if valid_dataset:
            if not self._metric_fns:
                raise RuntimeError('If `valid_dataset` is defined, the metric functions can not be None or empty.')

            self._eval_network.set_train(False)
            self._eval_network.phase = 'eval'
            valid_dataset.__no_send__ = True
            valid_dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
                                                                       is_train=False,
                                                                       phase='eval',
                                                                       dataset=valid_dataset,
                                                                       dataset_sink_mode=True)
            self._eval_network = eval_network
            for inputs in valid_dataset_helper:
                self._eval_network.compile(*inputs)
                break

    def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1):
        """
        Training.

        Args:
            epoch (int): Total number of iterations on the data.
            train_dataset (Dataset): A training dataset iterator. If there is no
                loss_fn, a tuple with multiple data items (data1, data2, data3, ...) will be
                returned and passed to the network. Otherwise, a tuple (data, label) will
                be returned. The data and label are passed to the network and loss
                function respectively.
            callbacks (list): List of callback objects, which should be executed while training. Default: None.
            dataset_sink_mode (bool): Determines whether to pass the data through the dataset channel.
                Default: True. In PyNative mode, training is performed without dataset sink.
            sink_size (int): Controls the amount of data sent in each sink. Default: -1.
        """
        epoch = check_int_positive(epoch)
        self._train_network.set_train()

        if self._parameter_broadcast:
            self._train_network.set_broadcast_flag()

        cb_params = _InternalCallbackParam()
        cb_params.train_network = self._train_network
        cb_params.epoch_num = epoch
        if dataset_sink_mode and sink_size > 0:
            cb_params.batch_num = sink_size
        else:
            cb_params.batch_num = train_dataset.get_dataset_size()
        cb_params.mode = "train"
        cb_params.loss_fn = self._loss_fn
        cb_params.optimizer = self._optimizer
        cb_params.parallel_mode = self._parallel_mode
        cb_params.device_number = self._device_number
        cb_params.train_dataset = train_dataset
        cb_params.list_callback = self._transform_callbacks(callbacks)
        cb_params.train_dataset_element = None
        cb_params.network = self._network
        ms_role = os.getenv("MS_ROLE")
        if ms_role in ("MS_PSERVER", "MS_SCHED"):
            epoch = 1

        # build callback list
        with _CallbackManager(callbacks) as list_callback:
            if not dataset_sink_mode:
                self._train_process(epoch, train_dataset, list_callback, cb_params)
            elif context.get_context("mode") == context.PYNATIVE_MODE:
                logger.warning("The pynative mode cannot support dataset sink mode currently. "
                               "So the training process will be performed with dataset not sink.")
                self._train_process(epoch, train_dataset, list_callback, cb_params)
            else:
                self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params, sink_size)

    @staticmethod
    def _transform_callbacks(callbacks):
        """Transform callback to a list."""
        if callbacks is None:
            return []

        if isinstance(callbacks, Iterable):
            return list(callbacks)

        return [callbacks]
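    # Illustration of the normalization above: None -> [], a single callback
    # such as LossMonitor() -> [LossMonitor()], and any iterable of callbacks
    # -> a plain list of the same callbacks. (LossMonitor is just an example
    # callback here.)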

    def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None, sink_size=-1):
        """
        Training process. The data would be passed to the network through the dataset channel.

        Args:
            epoch (int): Total number of iterations on the data.
            train_dataset (Dataset): A training dataset iterator. If there is no
                loss_fn, a tuple with multiple data items (data1, data2, data3, ...) should be
                returned and passed to the network. Otherwise, a tuple (data, label) should
                be returned. The data and label are passed to the network and loss
                function respectively.
            list_callback (Callback): Executor of callback list. Default: None.
            cb_params (_InternalCallbackParam): Callback parameters. Default: None.
            sink_size (int): Controls the amount of data sent in each sink. Default: -1.
        """
        dataset_helper, train_network = self._exec_preprocess(self._train_network,
                                                              is_train=True,
                                                              phase='train',
                                                              dataset=train_dataset,
                                                              dataset_sink_mode=True,
                                                              sink_size=sink_size)
        self._train_network = train_network
        cb_params.train_network = self._train_network
        cb_params.cur_step_num = 0

        run_context = RunContext(cb_params)
        list_callback.begin(run_context)

        # used to stop training for early stopping, such as stopAtTime or stopAtStep
        should_stop = False
        for i in range(epoch):
            cb_params.cur_epoch_num = i + 1
            list_callback.epoch_begin(run_context)

            # in data-sink mode dataset_helper iterates only once; otherwise it iterates epoch_size times
            for inputs in dataset_helper:
                if _need_to_full():
                    inputs = _to_full_tensor(inputs, self._device_number, self._global_rank)
                list_callback.step_begin(run_context)
                outputs = self._train_network(*inputs)
                cb_params.cur_step_num += dataset_helper.sink_size()
                cb_params.net_outputs = outputs
                list_callback.step_end(run_context)

            list_callback.epoch_end(run_context)
            should_stop = should_stop or run_context.get_stop_requested()
            if should_stop:
                break
        dataset_helper.stop_send()

        list_callback.end(run_context)

    def _train_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
        """
        Training process. The data would be passed to the network directly.

        Args:
            epoch (int): Total number of iterations on the data.
            train_dataset (Dataset): A training dataset iterator. If there is no
                loss_fn, a tuple with multiple data items (data1, data2, data3, ...) should be
                returned and passed to the network. Otherwise, a tuple (data, label) should
                be returned. The data and label are passed to the network and loss
                function respectively.
            list_callback (Callback): Executor of callback list. Default: None.
            cb_params (_InternalCallbackParam): Callback parameters. Default: None.
        """
        dataset_helper, _ = self._exec_preprocess(self._train_network,
                                                  is_train=True,
                                                  phase='train',
                                                  dataset=train_dataset,
                                                  dataset_sink_mode=False)
        cb_params.cur_step_num = 0
        run_context = RunContext(cb_params)
        list_callback.begin(run_context)
        # used to stop training for early stopping, such as stopAtTime or stopAtStep
        should_stop = False

        for i in range(epoch):
            cb_params.cur_epoch_num = i + 1

            list_callback.epoch_begin(run_context)

            for next_element in dataset_helper:
                len_element = len(next_element)
                if self._loss_fn and len_element != 2:
                    raise ValueError("when loss_fn is not None, train_dataset should "
                                     "return two elements, but got {}".format(len_element))
                cb_params.cur_step_num += 1
                list_callback.step_begin(run_context)

                overflow = False
                if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
                    scaling_sens = self._get_scaling_sens()
                    next_element = tuple(next_element) + (Tensor(scaling_sens, mstype.float32),)

                cb_params.train_dataset_element = next_element
                outputs = self._train_network(*next_element)
                cb_params.net_outputs = outputs
                if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
                    _, overflow, _ = outputs
                    overflow = np.all(overflow.asnumpy())
                    self._loss_scale_manager.update_loss_scale(overflow)

                list_callback.step_end(run_context)
                should_stop = should_stop or run_context.get_stop_requested()
                if should_stop:
                    break

            train_dataset.reset()

            list_callback.epoch_end(run_context)
            should_stop = should_stop or run_context.get_stop_requested()
            if should_stop:
                break

        list_callback.end(run_context)

    def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1):
        """
        Training API where the iteration is controlled by the python front-end.

        In PyNative mode, the training process is performed without dataset sink.

        Note:
            CPU is not supported when dataset_sink_mode is true.
            If dataset_sink_mode is True, the epoch of training should be equal to the count of repeat
            operations in dataset processing. Otherwise, errors could occur since the amount of data
            is not the amount training requires.
            If dataset_sink_mode is True, data will be sent to the device. If the device is Ascend, features
            of data will be transferred one by one. The limitation of data transmission per time is 256M.

        Args:
            epoch (int): Total number of iterations on the data.
            train_dataset (Dataset): A training dataset iterator. If there is no
                loss_fn, a tuple with multiple data items (data1, data2, data3, ...) should be
                returned and passed to the network. Otherwise, a tuple (data, label) should
                be returned. The data and label are passed to the network and loss
                function respectively.
            callbacks (list): List of callback objects, which should be executed while training. Default: None.
            dataset_sink_mode (bool): Determines whether to pass the data through the dataset channel.
                Default: True. In PyNative mode, training is performed without dataset sink.
            sink_size (int): Controls the amount of data sent in each sink.
                If sink_size=-1, sink the complete dataset each epoch.
                If sink_size>0, sink sink_size data items each epoch.
                If dataset_sink_mode is False, sink_size is ignored. Default: -1.

        Examples:
            >>> dataset = get_dataset()
            >>> net = Net()
            >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
            >>> loss_scale_manager = FixedLossScaleManager()
            >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
            >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
            >>> model.train(2, dataset)
        """
        check_bool(dataset_sink_mode)
        check_int(sink_size)
        if sink_size < -1 or sink_size == 0:
            raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))

        _device_number_check(self._parallel_mode, self._device_number)
        _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)

        self._train(epoch,
                    train_dataset,
                    callbacks=callbacks,
                    dataset_sink_mode=dataset_sink_mode,
                    sink_size=sink_size)
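    # A hedged usage sketch of sink_size (hypothetical numbers): with
    # dataset_sink_mode=True and sink_size=1000, cb_params.batch_num is set to
    # 1000 in _train(), so each epoch seen by the callbacks covers 1000 steps
    # regardless of the dataset's full size.
    #
    #   >>> model.train(10, dataset, dataset_sink_mode=True, sink_size=1000)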

    def _eval_dataset_sink_process(self, valid_dataset, list_callback=None, cb_params=None):
        """
        Evaluation. The data would be passed to the network through the dataset channel.

        Args:
            valid_dataset (Dataset): Dataset to evaluate the model.
            list_callback (Callback): Executor of callback list. Default: None.
            cb_params (_InternalCallbackParam): Callback parameters. Default: None.

        Returns:
            Dict, returns the loss value & metrics values for the model in test mode.
        """
        run_context = RunContext(cb_params)

        dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
                                                             is_train=False,
                                                             phase='eval',
                                                             dataset=valid_dataset,
                                                             dataset_sink_mode=True)
        self._eval_network = eval_network
        cb_params.eval_network = self._eval_network
        list_callback.begin(run_context)

        for inputs in dataset_helper:
            cb_params.cur_step_num += 1
            list_callback.step_begin(run_context)

            outputs = self._eval_network(*inputs)

            cb_params.net_outputs = outputs
            list_callback.step_end(run_context)
            self._update_metrics(outputs)

        metrics = self._get_metrics()
        cb_params.metrics = metrics
        list_callback.end(run_context)

        return metrics

    def _eval_process(self, valid_dataset, list_callback=None, cb_params=None):
        """
        Evaluation. The data would be passed to the network directly.

        Args:
            valid_dataset (Dataset): Dataset to evaluate the model.
            list_callback (Callback): Executor of callback list. Default: None.
            cb_params (_InternalCallbackParam): Callback parameters. Default: None.

        Returns:
            Dict, returns the loss value & metrics values for the model in test mode.
        """
        run_context = RunContext(cb_params)
        list_callback.begin(run_context)

        dataset_helper, _ = self._exec_preprocess(self._eval_network,
                                                  is_train=False,
                                                  phase='eval',
                                                  dataset=valid_dataset,
                                                  dataset_sink_mode=False)
        for next_element in dataset_helper:
            cb_params.cur_step_num += 1
            list_callback.step_begin(run_context)
            outputs = self._eval_network(*next_element)
            cb_params.net_outputs = outputs
            list_callback.step_end(run_context)
            self._update_metrics(outputs)

        metrics = self._get_metrics()
        cb_params.metrics = metrics
        list_callback.end(run_context)
        return metrics

    def eval(self, valid_dataset, callbacks=None, dataset_sink_mode=True):
        """
        Evaluation API where the iteration is controlled by the python front-end.

        In PyNative mode, the evaluation is performed with dataset non-sink mode.

        Note:
            CPU is not supported when dataset_sink_mode is true.
            If dataset_sink_mode is True, data will be sent to the device. If the device is Ascend, features
            of data will be transferred one by one. The limitation of data transmission per time is 256M.

        Args:
            valid_dataset (Dataset): Dataset to evaluate the model.
            callbacks (list): List of callback objects, which should be executed
                while evaluating. Default: None.
            dataset_sink_mode (bool): Determines whether to pass the data through the dataset channel.
                Default: True.

        Returns:
            Dict, returns the loss value & metrics values for the model in test mode.

        Examples:
            >>> dataset = get_dataset()
            >>> net = Net()
            >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
            >>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
            >>> model.eval(dataset)
        """
        check_bool(dataset_sink_mode)
        _device_number_check(self._parallel_mode, self._device_number)
        if not self._metric_fns:
            raise ValueError("metric fn can not be None or empty.")

        cb_params = _InternalCallbackParam()
        cb_params.eval_network = self._eval_network
        cb_params.valid_dataset = valid_dataset
        cb_params.batch_num = valid_dataset.get_dataset_size()
        cb_params.mode = "eval"
        cb_params.cur_step_num = 0
        cb_params.list_callback = self._transform_callbacks(callbacks)
        cb_params.network = self._network

        self._eval_network.set_train(mode=False)
        self._eval_network.phase = 'eval'
        self._clear_metrics()

        with _CallbackManager(callbacks) as list_callback:
            if dataset_sink_mode:
                return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
            return self._eval_process(valid_dataset, list_callback, cb_params)

    def predict(self, *predict_data):
        """
        Generates output predictions for the input samples.

        Data could be a single tensor, a list of tensors, or a tuple of tensors.

        Note:
            Batch data should be put together in one tensor.

        Args:
            predict_data (Tensor): Tensor of predict data, which can be an array, a list, or a tuple.

        Returns:
            Tensor, array(s) of predictions.

        Examples:
            >>> input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
            >>> model = Model(Net())
            >>> model.predict(input_data)
        """
        self._predict_network.set_train(False)
        check_input_data(*predict_data, data_class=Tensor)
        result = self._predict_network(*predict_data)
        check_output_data(result)
        return result


__all__ = ["Model"]