# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Model."""
from collections.abc import Iterable
import os
import math
import numpy as np
from mindspore import log as logger
from ..common.tensor import Tensor
from ..common import dtype as mstype
from ..nn.metrics import get_metrics, Loss
from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool, check_int
from .callback import _InternalCallbackParam, RunContext, _CallbackManager
from .. import context
from .. import nn
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
    _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check, \
    _need_to_full, _to_full_tensor
from .parallel_utils import ParallelMode
from .dataset_helper import DatasetHelper
from . import amp

class Model:
    """
    High-level API for training or testing.

    `Model` groups layers into an object with training and inference features.

    Args:
        network (Cell): A training or testing network.
        loss_fn (Cell): Objective function. If `loss_fn` is None, the network should contain
            the logic of loss and gradient calculation, and the parallel logic if needed.
            Default: None.
        optimizer (Cell): Optimizer for updating the weights. Default: None.
        metrics (Union[dict, set]): A dictionary or a set of metrics to be evaluated by the model
            during training and testing, e.g. {'accuracy', 'recall'}. Default: None.
        eval_network (Cell): Network for evaluation. If not defined, `network` and `loss_fn` would be
            wrapped as `eval_network`. Default: None.
        eval_indexes (list): When defining the `eval_network`, if `eval_indexes` is None, all outputs of the
            `eval_network` would be passed to metrics; otherwise `eval_indexes` must contain three
            elements: the positions of the loss value, the predicted value and the label. The loss
            value would be passed to the `Loss` metric, and the predicted value and label would be
            passed to the other metrics. Default: None.
        amp_level (str): Option for the `level` argument of `mindspore.amp.build_train_network`, the level
            for mixed-precision training. Supports ["O0", "O2", "O3"]. Default: "O0".

            - O0: Do not change.
            - O2: Cast the network to float16, keep batchnorm running in float32, and use dynamic loss scaling.
            - O3: Cast the network to float16, with the additional property 'keep_batchnorm_fp32=False'.

            O2 is recommended on GPU, O3 is recommended on Ascend.
        loss_scale_manager (Union[None, LossScaleManager]): If it is None, the loss would not be scaled;
            otherwise, the loss is scaled by the LossScaleManager. It is a keyword argument,
            e.g. pass `loss_scale_manager=None` to set the value.
        keep_batchnorm_fp32 (bool): Keep batchnorm running in `float32`. If set, it overrides the
            behavior implied by `amp_level`. Default: True.

    Examples:
        >>> class Net(nn.Cell):
        >>>     def __init__(self):
        >>>         super(Net, self).__init__()
        >>>         self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
        >>>         self.bn = nn.BatchNorm2d(64)
        >>>         self.relu = nn.ReLU()
        >>>         self.flatten = nn.Flatten()
        >>>         self.fc = nn.Dense(64*224*224, 12)  # padding=0
        >>>
        >>>     def construct(self, x):
        >>>         x = self.conv(x)
        >>>         x = self.bn(x)
        >>>         x = self.relu(x)
        >>>         x = self.flatten(x)
        >>>         out = self.fc(x)
        >>>         return out
        >>>
        >>> net = Net()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
        >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
        >>> dataset = get_dataset()
        >>> model.train(2, dataset)
    """
    def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, eval_network=None,
                 eval_indexes=None, amp_level="O0", **kwargs):
        self._network = network
        self._loss_fn = loss_fn
        self._optimizer = optimizer
        self._loss_scale_manager = None
        self._loss_scale_manager_set = False
        self._keep_bn_fp32 = True
        self._check_kwargs(kwargs)
        self._amp_level = amp_level
        self._process_amp_args(kwargs)
        self._parallel_mode = _get_parallel_mode()
        self._device_number = _get_device_num()
        self._global_rank = _get_global_rank()
        self._parameter_broadcast = _get_parameter_broadcast()

        self._train_network = self._build_train_network()
        self._build_eval_network(metrics, eval_network, eval_indexes)
        self._build_predict_network()

    def _process_amp_args(self, kwargs):
        if self._amp_level in ["O0", "O3"]:
            self._keep_bn_fp32 = False
        if 'keep_batchnorm_fp32' in kwargs:
            self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32']
        if 'loss_scale_manager' in kwargs:
            self._loss_scale_manager = kwargs['loss_scale_manager']
            self._loss_scale_manager_set = True

    def _check_kwargs(self, kwargs):
        for arg in kwargs:
            if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32']:
                raise ValueError(f"Unsupported arg '{arg}'")
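
    # A minimal sketch of how the two supported kwargs reach this class; the
    # FixedLossScaleManager values below are illustrative, not taken from this file:
    #
    # >>> manager = FixedLossScaleManager(loss_scale=1024, drop_overflow_update=False)
    # >>> model = Model(net, loss_fn=loss, optimizer=optim,
    # ...               loss_scale_manager=manager,      # stored by _process_amp_args
    # ...               keep_batchnorm_fp32=True)        # overrides the amp_level default
    # >>> # Any other keyword, e.g. Model(net, foo=1), raises ValueError in _check_kwargs.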

    def _build_train_network(self):
        """Build train network."""
        network = self._network
        if self._optimizer:
            if self._loss_scale_manager_set:
                network = amp.build_train_network(network,
                                                  self._optimizer,
                                                  self._loss_fn,
                                                  level=self._amp_level,
                                                  loss_scale_manager=self._loss_scale_manager,
                                                  keep_batchnorm_fp32=self._keep_bn_fp32)
            else:
                network = amp.build_train_network(network,
                                                  self._optimizer,
                                                  self._loss_fn,
                                                  level=self._amp_level,
                                                  keep_batchnorm_fp32=self._keep_bn_fp32)
        elif self._loss_fn:
            network = nn.WithLossCell(network, self._loss_fn)
        # TODO: may need to check the case where loss_fn is not None but optimizer is None
        if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
            network.set_auto_parallel()
        return network
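
    # A summary of the branches above (documentation only, no additional behavior):
    #
    #   optimizer given            -> amp.build_train_network(...), a one-step training
    #                                 wrapper with optional loss scaling
    #   loss_fn only, no optimizer -> nn.WithLossCell(network, loss_fn); the caller must
    #                                 handle parameter updates itself
    #   neither                    -> the raw network, which must contain its own loss
    #                                 and gradient logic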

    def _build_eval_network(self, metrics, eval_network, eval_indexes):
        """Build the network for evaluation."""
        self._metric_fns = get_metrics(metrics)
        if not self._metric_fns:
            return

        if eval_network is not None:
            if eval_indexes is not None and not (isinstance(eval_indexes, list) and len(eval_indexes) == 3):
                raise ValueError("`eval_indexes` must be a list or None. If `eval_indexes` is a list, "
                                 "its length must be three. But got {}".format(eval_indexes))

            self._eval_network = eval_network
            self._eval_indexes = eval_indexes
        else:
            if self._loss_fn is None:
                raise ValueError("`loss_fn` cannot be None.")
            self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level in ["O2", "O3"])
            self._eval_indexes = [0, 1, 2]

        if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
            if self._optimizer:
                self._eval_network = _VirtualDatasetCell(self._eval_network)
            self._eval_network.set_auto_parallel()

    def _build_predict_network(self):
        """Build the network for prediction."""
        self._predict_network = self._network
        if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
            self._predict_network = _VirtualDatasetCell(self._network)
            self._predict_network.set_auto_parallel()

    def _clear_metrics(self):
        """Clear metrics local values."""
        for metric in self._metric_fns.values():
            metric.clear()

    def _update_metrics(self, outputs):
        """Update metrics local values."""
        if not isinstance(outputs, tuple):
            raise ValueError("The `outputs` is not a tuple.")
        if self._eval_indexes is not None and len(outputs) < 3:
            raise ValueError("The length of `outputs` must be greater than or equal to 3, "
                             "but got {}".format(len(outputs)))

        for metric in self._metric_fns.values():
            if self._eval_indexes is None:
                metric.update(*outputs)
            else:
                if isinstance(metric, Loss):
                    metric.update(outputs[self._eval_indexes[0]])
                else:
                    metric.update(outputs[self._eval_indexes[1]], outputs[self._eval_indexes[2]])
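
    # A worked example of the routing above, assuming the default layout produced by
    # WithEvalCell, i.e. eval_indexes = [0, 1, 2] and outputs = (loss, logits, label):
    #
    #   Loss metric      <- outputs[0]               (the scalar loss)
    #   Accuracy metric  <- outputs[1], outputs[2]   (predictions and labels)
    #
    # With eval_indexes=None every metric receives the full `outputs` tuple unchanged.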

    def _get_metrics(self):
        """Get metrics local values."""
        metrics = dict()
        for key, value in self._metric_fns.items():
            metrics[key] = value.eval()
        return metrics

    def _get_scaling_sens(self):
        """Get the scaling sens."""
        scaling_sens = 1
        if self._loss_scale_manager is not None:
            scaling_sens = self._loss_scale_manager.get_loss_scale()
        if self._parallel_mode == ParallelMode.DATA_PARALLEL:
            scaling_sens /= self._device_number
        return scaling_sens
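
    # A quick numeric check of _get_scaling_sens, under assumed values (not from this
    # file): with a loss scale of 1024 and 8 devices in DATA_PARALLEL mode, the
    # sensitivity fed to the backward graph is 1024 / 8 = 128, so the per-device
    # gradients sum back to the full scaled gradient across the group.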

    def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1, epoch_num=1):
        """Initializes dataset."""
        need_wrap = False
        if dataset_sink_mode:
            # remove later to deal with loop sink
            if not hasattr(dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \
                    and not context.get_context("enable_ge"):
                need_wrap = True

            if not is_train:
                dataset.__loop_size__ = 1

        dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num)

        # remove later to deal with loop sink
        if need_wrap:
            network = nn.DataWrapper(network, *(dataset_helper.types_shapes()), dataset.__ME_INITED__)
        network.set_train(is_train)
        network.phase = phase

        if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
            network.set_auto_parallel()

        return dataset_helper, network

    def init(self, train_dataset=None, valid_dataset=None):
        """
        Initialize compute graphs and data graphs with the sink mode.

        Note:
            The pre-init process only supports `GRAPH_MODE` and the `Ascend` target currently.

        Args:
            train_dataset (Dataset): A training dataset iterator. If `train_dataset` is defined, training
                graphs will be initialized. Default: None.
            valid_dataset (Dataset): An evaluation dataset iterator. If `valid_dataset` is defined, evaluation
                graphs will be initialized, and `metrics` in `Model` cannot be None. Default: None.

        Examples:
            >>> train_dataset = get_train_dataset()
            >>> valid_dataset = get_valid_dataset()
            >>> net = Net()
            >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
            >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
            >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={'acc'})
            >>> model.init(train_dataset, valid_dataset)
            >>> model.train(2, train_dataset)
            >>> model.eval(valid_dataset)
        """
        if context.get_context("mode") != context.GRAPH_MODE or context.get_context("device_target") != "Ascend":
            raise RuntimeError('Pre-init process only supports GRAPH MODE and Ascend target currently.')
        if not train_dataset and not valid_dataset:
            raise ValueError('`train_dataset` and `valid_dataset` cannot both be None or empty.')
        _device_number_check(self._parallel_mode, self._device_number)

        if train_dataset:
            _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
            self._train_network.set_train()
            self._train_network.phase = 'train'

            if self._parameter_broadcast:
                self._train_network.set_broadcast_flag()
            train_dataset.__no_send__ = True
            train_dataset_helper, train_network = self._exec_preprocess(self._train_network,
                                                                        is_train=True,
                                                                        phase='train',
                                                                        dataset=train_dataset,
                                                                        dataset_sink_mode=True)
            self._train_network = train_network
            for inputs in train_dataset_helper:
                self._train_network.compile(*inputs)
                break

        if valid_dataset:
            if not self._metric_fns:
                raise RuntimeError('If `valid_dataset` is defined, the metric functions cannot be None or empty.')

            self._eval_network.set_train(False)
            self._eval_network.phase = 'eval'
            valid_dataset.__no_send__ = True
            valid_dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
                                                                       is_train=False,
                                                                       phase='eval',
                                                                       dataset=valid_dataset,
                                                                       dataset_sink_mode=True)
            self._eval_network = eval_network
            for inputs in valid_dataset_helper:
                self._eval_network.compile(*inputs)
                break

    def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1):
        """
        Training.

        Args:
            epoch (int): Total number of iterations on the data.
            train_dataset (Dataset): A training dataset iterator. If there is no
                loss_fn, a tuple with multiple data (data1, data2, data3, ...) will be
                returned and passed to the network. Otherwise, a tuple (data, label) will
                be returned. The data and label would be passed to the network and loss
                function respectively.
            callbacks (list): List of callback objects which should be executed while training. Default: None.
            dataset_sink_mode (bool): Determines whether the data should be passed through the dataset channel.
                Default: True. In pynative mode, training is performed without dataset sinking.
            sink_size (int): Control the amount of data in each sink. Default: -1.
        """
        epoch = check_int_positive(epoch)
        self._train_network.set_train()

        if self._parameter_broadcast:
            self._train_network.set_broadcast_flag()

        cb_params = _InternalCallbackParam()
        cb_params.train_network = self._train_network
        cb_params.epoch_num = epoch
        if dataset_sink_mode and sink_size > 0:
            cb_params.batch_num = sink_size
        else:
            cb_params.batch_num = train_dataset.get_dataset_size()
        cb_params.mode = "train"
        cb_params.loss_fn = self._loss_fn
        cb_params.optimizer = self._optimizer
        cb_params.parallel_mode = self._parallel_mode
        cb_params.device_number = self._device_number
        cb_params.train_dataset = train_dataset
        cb_params.list_callback = self._transform_callbacks(callbacks)
        cb_params.train_dataset_element = None
        cb_params.network = self._network
        ms_role = os.getenv("MS_ROLE")
        if ms_role in ("MS_PSERVER", "MS_SCHED"):
            epoch = 1

        # build callback list
        with _CallbackManager(callbacks) as list_callback:
            if not dataset_sink_mode:
                self._train_process(epoch, train_dataset, list_callback, cb_params)
            elif context.get_context("mode") == context.PYNATIVE_MODE:
                logger.warning("The pynative mode cannot support dataset sink mode currently. "
                               "So the training process will be performed with dataset not sink.")
                self._train_process(epoch, train_dataset, list_callback, cb_params)
            else:
                self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params, sink_size)

    @staticmethod
    def _transform_callbacks(callbacks):
        """Transform callback to a list."""
        if callbacks is None:
            return []

        if isinstance(callbacks, Iterable):
            return list(callbacks)

        return [callbacks]

    def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None, sink_size=-1):
        """
        Training process. The data would be passed to the network through the dataset channel.

        Args:
            epoch (int): Total number of iterations on the data.
            train_dataset (Dataset): A training dataset iterator. If there is no
                loss_fn, a tuple with multiple data (data1, data2, data3, ...) should be
                returned and passed to the network. Otherwise, a tuple (data, label) should
                be returned. The data and label would be passed to the network and loss
                function respectively.
            list_callback (Callback): Executor of callback list. Default: None.
            cb_params (_InternalCallbackParam): Callback parameters. Default: None.
            sink_size (int): Control the amount of data in each sink. Default: -1.
        """
        if sink_size == -1:
            epoch_num = epoch
        else:
            epoch_num = math.ceil(epoch * sink_size / train_dataset.get_dataset_size())

        dataset_helper, train_network = self._exec_preprocess(self._train_network,
                                                              is_train=True,
                                                              phase='train',
                                                              dataset=train_dataset,
                                                              dataset_sink_mode=True,
                                                              sink_size=sink_size,
                                                              epoch_num=epoch_num)
        self._train_network = train_network
        cb_params.train_network = self._train_network
        cb_params.cur_step_num = 0

        run_context = RunContext(cb_params)
        list_callback.begin(run_context)

        # used to stop training early, e.g. stop at a given time or step
        should_stop = False
        for i in range(epoch):
            cb_params.cur_epoch_num = i + 1
            list_callback.epoch_begin(run_context)

            # in data sink mode, dataset_helper iterates only once per epoch; otherwise
            # it iterates epoch_size times.
            for inputs in dataset_helper:
                if _need_to_full() and context.get_context("device_target") == "GPU":
                    inputs = _to_full_tensor(inputs, self._device_number, self._global_rank)
                list_callback.step_begin(run_context)
                outputs = self._train_network(*inputs)
                cb_params.cur_step_num += dataset_helper.sink_size()
                cb_params.net_outputs = outputs
                list_callback.step_end(run_context)

            list_callback.epoch_end(run_context)
            should_stop = should_stop or run_context.get_stop_requested()
            if should_stop:
                break
        dataset_helper.stop_send()

        list_callback.end(run_context)
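
    # A numeric illustration of the epoch_num rescaling above, with assumed sizes
    # (epoch=3, sink_size=100, dataset of 250 batches):
    #   epoch_num = ceil(3 * 100 / 250) = ceil(1.2) = 2
    # so the dataset channel is prepared for 2 passes over the data, while the outer
    # loop still runs 3 "epochs" of 100 sunk steps each.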

    def _train_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
        """
        Training process. The data would be passed to the network directly.

        Args:
            epoch (int): Total number of iterations on the data.
            train_dataset (Dataset): A training dataset iterator. If there is no
                loss_fn, a tuple with multiple data (data1, data2, data3, ...) should be
                returned and passed to the network. Otherwise, a tuple (data, label) should
                be returned. The data and label would be passed to the network and loss
                function respectively.
            list_callback (Callback): Executor of callback list. Default: None.
            cb_params (_InternalCallbackParam): Callback parameters. Default: None.
        """
        dataset_helper, _ = self._exec_preprocess(self._train_network,
                                                  is_train=True,
                                                  phase='train',
                                                  dataset=train_dataset,
                                                  dataset_sink_mode=False)
        cb_params.cur_step_num = 0
        run_context = RunContext(cb_params)
        list_callback.begin(run_context)
        # used to stop training early, e.g. stop at a given time or step
        should_stop = False

        for i in range(epoch):
            cb_params.cur_epoch_num = i + 1

            list_callback.epoch_begin(run_context)

            for next_element in dataset_helper:
                len_element = len(next_element)
                if self._loss_fn and len_element != 2:
                    raise ValueError("When loss_fn is not None, train_dataset should "
                                     "return two elements, but got {}".format(len_element))
                cb_params.cur_step_num += 1
                list_callback.step_begin(run_context)

                overflow = False
                if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
                    scaling_sens = self._get_scaling_sens()
                    next_element = tuple(next_element) + (Tensor(scaling_sens, mstype.float32),)

                cb_params.train_dataset_element = next_element
                outputs = self._train_network(*next_element)
                cb_params.net_outputs = outputs
                if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
                    _, overflow, _ = outputs
                    overflow = np.all(overflow.asnumpy())
                    self._loss_scale_manager.update_loss_scale(overflow)

                list_callback.step_end(run_context)
                if os.getenv("MS_ROLE") == "MS_PSERVER":
                    os._exit(0)
                should_stop = should_stop or run_context.get_stop_requested()
                if should_stop:
                    break

            train_dataset.reset()

            list_callback.epoch_end(run_context)
            should_stop = should_stop or run_context.get_stop_requested()
            if should_stop:
                break

        list_callback.end(run_context)

    def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1):
        """
        Training API where the iteration is controlled by the python front-end.

        When pynative mode is set, the training process will be performed with dataset not sink.

        Note:
            CPU is not supported when dataset_sink_mode is True.
            If dataset_sink_mode is True, the epoch of training should be equal to the count of repeat
            operations in the dataset processing. Otherwise, errors could occur since the amount of data
            is not equal to the required amount of training.
            If dataset_sink_mode is True, data will be sent to the device. If the device is Ascend, features
            of data will be transferred one by one. The limitation of data transmission per time is 256M.
            If sink_size > 0, each epoch the dataset can be traversed unlimited times until you get sink_size
            elements of the dataset. The next epoch continues to traverse from the end position of the
            previous traversal.

        Args:
            epoch (int): Generally, the total number of iterations on the data per epoch.
                When dataset_sink_mode is set to True and sink_size > 0, each epoch sinks sink_size
                steps on the data instead of the total number of iterations.
            train_dataset (Dataset): A training dataset iterator. If there is no
                loss_fn, a tuple with multiple data (data1, data2, data3, ...) should be
                returned and passed to the network. Otherwise, a tuple (data, label) should
                be returned. The data and label would be passed to the network and loss
                function respectively.
            callbacks (list): List of callback objects which should be executed while training. Default: None.
            dataset_sink_mode (bool): Determines whether to pass the data through the dataset channel.
                Default: True. In pynative mode, training is performed without dataset sinking.
            sink_size (int): Control the amount of data in each sink.
                If sink_size = -1, sink the complete dataset for each epoch.
                If sink_size > 0, sink sink_size data for each epoch.
                If dataset_sink_mode is False, sink_size is invalid. Default: -1.

        Examples:
            >>> dataset = get_dataset()
            >>> net = Net()
            >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
            >>> loss_scale_manager = FixedLossScaleManager()
            >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
            >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
            >>> model.train(2, dataset)
        """
        check_bool(dataset_sink_mode)
        check_int(sink_size)
        if sink_size < -1 or sink_size == 0:
            raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))

        _device_number_check(self._parallel_mode, self._device_number)
        _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)

        self._train(epoch,
                    train_dataset,
                    callbacks=callbacks,
                    dataset_sink_mode=dataset_sink_mode,
                    sink_size=sink_size)
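
    # A hedged usage sketch for sink_size; the dataset size is an assumption for
    # illustration, not taken from this file:
    #
    # >>> # suppose the dataset yields 1000 batches per pass
    # >>> model.train(5, dataset, dataset_sink_mode=True, sink_size=200)
    # >>> # each of the 5 epochs sinks 200 steps; together they make one full traversal
    # >>> # of the 1000 batches, each epoch resuming where the previous one stopped.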

    def _eval_dataset_sink_process(self, valid_dataset, list_callback=None, cb_params=None):
        """
        Evaluation. The data would be passed to the network through the dataset channel.

        Args:
            valid_dataset (Dataset): Dataset to evaluate the model.
            list_callback (Callback): Executor of callback list. Default: None.
            cb_params (_InternalCallbackParam): Callback parameters. Default: None.

        Returns:
            Dict, which returns the loss value and metrics values for the model in the test mode.
        """
        run_context = RunContext(cb_params)

        dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
                                                             is_train=False,
                                                             phase='eval',
                                                             dataset=valid_dataset,
                                                             dataset_sink_mode=True)
        self._eval_network = eval_network
        cb_params.eval_network = self._eval_network
        list_callback.begin(run_context)

        for inputs in dataset_helper:
            cb_params.cur_step_num += 1
            list_callback.step_begin(run_context)

            outputs = self._eval_network(*inputs)

            cb_params.net_outputs = outputs
            list_callback.step_end(run_context)
            self._update_metrics(outputs)

        metrics = self._get_metrics()
        cb_params.metrics = metrics
        list_callback.end(run_context)

        return metrics

    def _eval_process(self, valid_dataset, list_callback=None, cb_params=None):
        """
        Evaluation. The data would be passed to the network directly.

        Args:
            valid_dataset (Dataset): Dataset to evaluate the model.
            list_callback (Callback): Executor of callback list. Default: None.
            cb_params (_InternalCallbackParam): Callback parameters. Default: None.

        Returns:
            Dict, which returns the loss value and metrics values for the model in the test mode.
        """
        run_context = RunContext(cb_params)
        list_callback.begin(run_context)

        dataset_helper, _ = self._exec_preprocess(self._eval_network,
                                                  is_train=False,
                                                  phase='eval',
                                                  dataset=valid_dataset,
                                                  dataset_sink_mode=False)
        for next_element in dataset_helper:
            cb_params.cur_step_num += 1
            list_callback.step_begin(run_context)
            outputs = self._eval_network(*next_element)
            cb_params.net_outputs = outputs
            list_callback.step_end(run_context)
            self._update_metrics(outputs)

        valid_dataset.reset()

        metrics = self._get_metrics()
        cb_params.metrics = metrics
        list_callback.end(run_context)
        return metrics

    def eval(self, valid_dataset, callbacks=None, dataset_sink_mode=True):
        """
        Evaluation API where the iteration is controlled by the python front-end.

        When pynative mode is configured, the evaluation is performed with the dataset in non-sink mode.

        Note:
            CPU is not supported when dataset_sink_mode is True.
            If dataset_sink_mode is True, data will be sent to the device. If the device is Ascend, features
            of data will be transferred one by one. The limitation of data transmission per time is 256M.

        Args:
            valid_dataset (Dataset): Dataset to evaluate the model.
            callbacks (list): List of callback objects which should be executed while evaluating. Default: None.
            dataset_sink_mode (bool): Determines whether to pass the data through the dataset channel.
                Default: True.

        Returns:
            Dict, which returns the loss value and metrics values for the model in the test mode.

        Examples:
            >>> dataset = get_dataset()
            >>> net = Net()
            >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
            >>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
            >>> model.eval(dataset)
        """
        check_bool(dataset_sink_mode)
        _device_number_check(self._parallel_mode, self._device_number)
        if not self._metric_fns:
            raise ValueError("The metric functions cannot be None or empty.")

        cb_params = _InternalCallbackParam()
        cb_params.eval_network = self._eval_network
        cb_params.valid_dataset = valid_dataset
        cb_params.batch_num = valid_dataset.get_dataset_size()
        cb_params.mode = "eval"
        cb_params.cur_step_num = 0
        cb_params.list_callback = self._transform_callbacks(callbacks)
        cb_params.network = self._network

        self._eval_network.set_train(mode=False)
        self._eval_network.phase = 'eval'
        self._clear_metrics()

        with _CallbackManager(callbacks) as list_callback:
            if dataset_sink_mode:
                return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
            return self._eval_process(valid_dataset, list_callback, cb_params)

    def predict(self, *predict_data):
        """
        Generate output predictions for the input samples.

        Data could be a single tensor, a list of tensors, or a tuple of tensors.

        Note:
            Batch data should be put together in one tensor.

        Args:
            predict_data (Tensor): The predict data; can be a single tensor, or a list or tuple of tensors.

        Returns:
            Tensor, array(s) of predictions.

        Examples:
            >>> input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
            >>> model = Model(Net())
            >>> model.predict(input_data)
        """
        self._predict_network.set_train(False)
        check_input_data(*predict_data, data_class=Tensor)
        result = self._predict_network(*predict_data)

        check_output_data(result)
        return result


__all__ = ["Model"]
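
# A minimal end-to-end sketch of the Model API above; `Net` and `get_dataset` are
# assumed helpers, not defined in this file:
#
# >>> net = Net()
# >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
# >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
# >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={'acc'})
# >>> model.train(2, get_dataset(), dataset_sink_mode=False)   # trains via _train_process
# >>> model.eval(get_dataset(), dataset_sink_mode=False)       # returns {'acc': ...}
# >>> model.predict(Tensor(np.ones([1, 3, 224, 224]), mstype.float32))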