You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

model.py 35 kB

5 years ago
6 years ago
5 years ago
6 years ago
6 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
6 years ago
5 years ago
6 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Model."""
  16. from collections.abc import Iterable
  17. import os
  18. import math
  19. import numpy as np
  20. from mindspore import log as logger
  21. from ..common.tensor import Tensor
  22. from ..nn.metrics import get_metrics
  23. from .._checkparam import check_input_data, check_output_data, Validator, check_int
  24. from .callback import _InternalCallbackParam, RunContext, _CallbackManager
  25. from .. import context
  26. from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
  27. _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
  28. from ..parallel._ps_context import _is_role_pserver, _is_role_sched
  29. from ..nn.metrics import Loss
  30. from .. import nn
  31. from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
  32. from ..context import ParallelMode
  33. from ..parallel._cost_model_context import _set_multi_subgraphs
  34. from .dataset_helper import DatasetHelper, connect_network_with_dataset
  35. from . import amp
  36. def _transfer_tensor_to_tuple(inputs):
  37. """
  38. If the input is a tensor, convert it to a tuple. If not, the output is unchanged.
  39. """
  40. if isinstance(inputs, Tensor):
  41. return (inputs,)
  42. return inputs
  43. class Model:
  44. """
  45. High-Level API for Training or Testing.
  46. `Model` groups layers into an object with training and inference features.
  47. Args:
  48. network (Cell): A training or testing network.
  49. loss_fn (Cell): Objective function, if loss_fn is None, the
  50. network should contain the logic of loss and grads calculation, and the logic
  51. of parallel if needed. Default: None.
  52. optimizer (Cell): Optimizer for updating the weights. Default: None.
  53. metrics (Union[dict, set]): A Dictionary or a set of metrics to be evaluated by the model during
  54. training and testing. eg: {'accuracy', 'recall'}. Default: None.
  55. eval_network (Cell): Network for evaluation. If not defined, `network` and `loss_fn` would be wrapped as
  56. `eval_network`. Default: None.
  57. eval_indexes (list): When defining the `eval_network`, if `eval_indexes` is None, all outputs of the
  58. `eval_network` would be passed to metrics, otherwise `eval_indexes` must contain three
  59. elements, including the positions of loss value, predicted value and label. The loss
  60. value would be passed to the `Loss` metric, the predicted value and label would be passed
  61. to other metric. Default: None.
  62. amp_level (str): Option for argument `level` in `mindspore.amp.build_train_network`, level for mixed
  63. precision training. Supports ["O0", "O2", "O3", "auto"]. Default: "O0".
  64. - O0: Do not change.
  65. - O2: Cast network to float16, keep batchnorm run in float32, using dynamic loss scale.
  66. - O3: Cast network to float16, with additional property 'keep_batchnorm_fp32=False'.
  67. - auto: Set to level to recommended level in different devices. Set level to O2 on GPU, Set
  68. level to O3 Ascend. The recommended level is choose by the export experience, cannot
  69. always generalize. User should specify the level for special network.
  70. O2 is recommended on GPU, O3 is recommended on Ascend.
  71. loss_scale_manager (Union[None, LossScaleManager]): If it is None, the loss would not be scaled. Otherwise,
  72. scale the loss by LossScaleManager and optimizer can not be None.It is a key argument.
  73. e.g. Use `loss_scale_manager=None` to set the value.
  74. keep_batchnorm_fp32 (bool): Keep Batchnorm running in `float32`. If it is set to true, the level setting before
  75. will be overwritten. Default: True.
  76. Examples:
  77. >>> class Net(nn.Cell):
  78. >>> def __init__(self):
  79. >>> super(Net, self).__init__()
  80. >>> self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
  81. >>> self.bn = nn.BatchNorm2d(64)
  82. >>> self.relu = nn.ReLU()
  83. >>> self.flatten = nn.Flatten()
  84. >>> self.fc = nn.Dense(64*224*224, 12) # padding=0
  85. >>>
  86. >>> def construct(self, x):
  87. >>> x = self.conv(x)
  88. >>> x = self.bn(x)
  89. >>> x = self.relu(x)
  90. >>> x = self.flatten(x)
  91. >>> out = self.fc(x)
  92. >>> return out
  93. >>>
  94. >>> net = Net()
  95. >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
  96. >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
  97. >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
  98. >>> dataset = get_dataset()
  99. >>> model.train(2, dataset)
  100. """
  101. def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, eval_network=None,
  102. eval_indexes=None, amp_level="O0", **kwargs):
  103. self._network = network
  104. self._loss_fn = loss_fn
  105. self._optimizer = optimizer
  106. self._loss_scale_manager = None
  107. self._loss_scale_manager_set = False
  108. self._keep_bn_fp32 = True
  109. self._check_kwargs(kwargs)
  110. self._amp_level = amp_level
  111. self._process_amp_args(kwargs)
  112. self._parallel_mode = _get_parallel_mode()
  113. self._device_number = _get_device_num()
  114. self._global_rank = _get_global_rank()
  115. self._parameter_broadcast = _get_parameter_broadcast()
  116. self._train_network = self._build_train_network()
  117. self._build_eval_network(metrics, eval_network, eval_indexes)
  118. self._build_predict_network()
  119. def _process_amp_args(self, kwargs):
  120. if self._amp_level in ["O0", "O3"]:
  121. self._keep_bn_fp32 = False
  122. if 'keep_batchnorm_fp32' in kwargs:
  123. self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32']
  124. if 'loss_scale_manager' in kwargs:
  125. self._loss_scale_manager = kwargs['loss_scale_manager']
  126. self._loss_scale_manager_set = True
  127. def _check_kwargs(self, kwargs):
  128. for arg in kwargs:
  129. if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32']:
  130. raise ValueError(f"Unsupported arg '{arg}'")
  131. def _build_train_network(self):
  132. """Build train network"""
  133. network = self._network
  134. if self._loss_scale_manager is not None and self._optimizer is None:
  135. raise ValueError("Optimizer can not be None when set loss_scale_manager.")
  136. if self._optimizer:
  137. if self._loss_scale_manager_set:
  138. network = amp.build_train_network(network,
  139. self._optimizer,
  140. self._loss_fn,
  141. level=self._amp_level,
  142. loss_scale_manager=self._loss_scale_manager,
  143. keep_batchnorm_fp32=self._keep_bn_fp32)
  144. else:
  145. network = amp.build_train_network(network,
  146. self._optimizer,
  147. self._loss_fn,
  148. level=self._amp_level,
  149. keep_batchnorm_fp32=self._keep_bn_fp32)
  150. elif self._loss_fn:
  151. network = nn.WithLossCell(network, self._loss_fn)
  152. # If need to check if loss_fn is not None, but optimizer is None
  153. if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
  154. network.set_auto_parallel()
  155. if self._optimizer is None:
  156. # In this case, multiple optimizer(s) is supposed to be included in 'self._network'
  157. _set_multi_subgraphs()
  158. return network
  159. def _build_eval_network(self, metrics, eval_network, eval_indexes):
  160. """Build the network for evaluation."""
  161. self._metric_fns = get_metrics(metrics)
  162. if not self._metric_fns:
  163. return
  164. if eval_network is not None:
  165. if eval_indexes is not None and not (isinstance(eval_indexes, list) and len(eval_indexes) == 3):
  166. raise ValueError("Eval_indexes must be a list or None. If eval_indexes is a list, length of it \
  167. must be three. But got {}".format(eval_indexes))
  168. self._eval_network = eval_network
  169. self._eval_indexes = eval_indexes
  170. else:
  171. if self._loss_fn is None:
  172. raise ValueError("loss_fn can not be None.")
  173. self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level in ["O2", "O3", "auto"])
  174. self._eval_indexes = [0, 1, 2]
  175. if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
  176. if self._optimizer:
  177. self._eval_network = _VirtualDatasetCell(self._eval_network)
  178. if self._optimizer is None:
  179. # In this case, multiple optimizer(s) is supposed to be included in 'self._network'
  180. _set_multi_subgraphs()
  181. self._eval_network.set_auto_parallel()
  182. def _build_predict_network(self):
  183. """Build the network for prediction."""
  184. self._predict_network = self._network
  185. if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
  186. self._predict_network = _VirtualDatasetCell(self._network)
  187. # Unlike the cases in build_train_network() and build_eval_network(), 'multi_subgraphs' is not set
  188. self._predict_network.set_auto_parallel()
  189. def _clear_metrics(self):
  190. """Clear metrics local values."""
  191. for metric in self._metric_fns.values():
  192. metric.clear()
  193. def _update_metrics(self, outputs):
  194. """Update metrics local values."""
  195. if not isinstance(outputs, tuple):
  196. raise ValueError("The `outputs` is not tuple.")
  197. if self._eval_indexes is not None and len(outputs) < 3:
  198. raise ValueError("The length of `outputs` must be greater than or equal to 3, \
  199. but got {}".format(len(outputs)))
  200. for metric in self._metric_fns.values():
  201. if self._eval_indexes is None:
  202. metric.update(*outputs)
  203. else:
  204. if isinstance(metric, Loss):
  205. metric.update(outputs[self._eval_indexes[0]])
  206. else:
  207. metric.update(outputs[self._eval_indexes[1]], outputs[self._eval_indexes[2]])
  208. def _get_metrics(self):
  209. """Get metrics local values."""
  210. metrics = dict()
  211. for key, value in self._metric_fns.items():
  212. metrics[key] = value.eval()
  213. return metrics
  214. def _get_scaling_sens(self):
  215. """get the scaling sens"""
  216. scaling_sens = 1
  217. if self._loss_scale_manager is not None:
  218. scaling_sens = self._loss_scale_manager.get_loss_scale()
  219. if self._parallel_mode == ParallelMode.DATA_PARALLEL:
  220. scaling_sens /= self._device_number
  221. return scaling_sens
  222. def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1, epoch_num=1):
  223. """Initializes dataset."""
  224. if dataset_sink_mode and not is_train:
  225. dataset.__loop_size__ = 1
  226. dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num)
  227. if dataset_sink_mode:
  228. network = connect_network_with_dataset(network, dataset_helper)
  229. network.set_train(is_train)
  230. network.phase = phase
  231. if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
  232. network.set_auto_parallel()
  233. return dataset_helper, network
  234. def _init(self, train_dataset=None, valid_dataset=None, sink_size=-1):
  235. """
  236. Initialize compute graphs and data graphs with the sink mode.
  237. Note:
  238. Pre-init process only supports `GRAPH_MODE` and `Ascend` target currently.
  239. Args:
  240. train_dataset (Dataset): A training dataset iterator. If `train_dataset` is defined, training graphs will be
  241. initialized. Default: None.
  242. valid_dataset (Dataset): A evaluating dataset iterator. If `valid_dataset` is defined, evaluation graphs
  243. will be initialized, and `metrics` in `Model` can not be None. Default: None.
  244. sink_size (int): Control the amount of data in each sink. Default: -1.
  245. """
  246. if context.get_context("mode") != context.GRAPH_MODE or context.get_context("device_target") != "Ascend":
  247. raise RuntimeError('Pre-init process only supports GRAPH MODE and Ascend target currently.')
  248. if not train_dataset and not valid_dataset:
  249. raise ValueError('Both train_dataset and valid_dataset can not be None or empty.')
  250. _device_number_check(self._parallel_mode, self._device_number)
  251. if train_dataset:
  252. _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
  253. if self._parameter_broadcast:
  254. self._train_network.set_broadcast_flag()
  255. train_dataset.__no_send__ = True
  256. train_dataset_helper, train_network = self._exec_preprocess(self._train_network,
  257. is_train=True,
  258. phase='train',
  259. dataset=train_dataset,
  260. dataset_sink_mode=True,
  261. sink_size=sink_size)
  262. self._train_network = train_network
  263. for inputs in train_dataset_helper:
  264. self._train_network.compile(*inputs)
  265. break
  266. if valid_dataset:
  267. if not self._metric_fns:
  268. raise RuntimeError('If define `valid_dataset`, metric fn can not be None or empty.')
  269. valid_dataset.__no_send__ = True
  270. valid_dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
  271. is_train=False,
  272. phase='eval',
  273. dataset=valid_dataset,
  274. dataset_sink_mode=True)
  275. self._eval_network = eval_network
  276. for inputs in valid_dataset_helper:
  277. self._eval_network.compile(*inputs)
  278. break
  279. def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1):
  280. """
  281. Training.
  282. Args:
  283. epoch (int): Total number of iterations on the data.
  284. train_dataset (Dataset): A training dataset iterator. If there is no
  285. loss_fn, a tuple with multiple data (data1, data2, data3, ...) will be
  286. returned and passed to the network. Otherwise, a tuple (data, label) will
  287. be returned. The data and label would be passed to the network and loss
  288. function respectively.
  289. callbacks (list): List of callback objects which should be executed while training. Default: None.
  290. dataset_sink_mode (bool): Determine whether the data should be passed through the dataset channel.
  291. Default: True.
  292. Configure pynative mode or CPU, the training process will be performed with
  293. dataset not sink.
  294. sink_size (int): Control the amount of data in each sink. Default: -1.
  295. """
  296. epoch = Validator.check_positive_int(epoch)
  297. if self._parameter_broadcast:
  298. self._train_network.set_broadcast_flag()
  299. cb_params = _InternalCallbackParam()
  300. cb_params.train_network = self._train_network
  301. cb_params.epoch_num = epoch
  302. if dataset_sink_mode and sink_size > 0:
  303. cb_params.batch_num = sink_size
  304. else:
  305. cb_params.batch_num = train_dataset.get_dataset_size()
  306. cb_params.mode = "train"
  307. cb_params.loss_fn = self._loss_fn
  308. cb_params.optimizer = self._optimizer
  309. cb_params.parallel_mode = self._parallel_mode
  310. cb_params.device_number = self._device_number
  311. cb_params.train_dataset = train_dataset
  312. cb_params.list_callback = self._transform_callbacks(callbacks)
  313. cb_params.train_dataset_element = None
  314. cb_params.network = self._network
  315. if _is_role_pserver() or _is_role_sched():
  316. epoch = 1
  317. # build callback list
  318. with _CallbackManager(callbacks) as list_callback:
  319. if not dataset_sink_mode:
  320. self._train_process(epoch, train_dataset, list_callback, cb_params)
  321. elif context.get_context("mode") == context.PYNATIVE_MODE or context.get_context("device_target") == "CPU":
  322. logger.warning("The pynative mode and CPU cannot support dataset sink mode currently."
  323. "So the training process will be performed with dataset not sink.")
  324. self._train_process(epoch, train_dataset, list_callback, cb_params)
  325. else:
  326. self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params, sink_size)
  327. @staticmethod
  328. def _transform_callbacks(callbacks):
  329. """Transform callback to a list."""
  330. if callbacks is None:
  331. return []
  332. if isinstance(callbacks, Iterable):
  333. return list(callbacks)
  334. return [callbacks]
  335. def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None, sink_size=-1):
  336. """
  337. Training process. The data would be passed to network through dataset channel.
  338. Args:
  339. epoch (int): Total number of iterations on the data.
  340. train_dataset (Dataset): A training dataset iterator. If there is no
  341. loss_fn, a tuple with multiple data (data1, data2, data3, ...) should be
  342. returned and passed to the network. Otherwise, a tuple (data, label) should
  343. be returned. The data and label would be passed to the network and loss
  344. function respectively.
  345. list_callback (Callback): Executor of callback list. Default: None.
  346. cb_params (_InternalCallbackParam): Callback parameters. Default: None.
  347. sink_size (int): Control the amount of data in each sink. Default: -1.
  348. """
  349. if sink_size == -1:
  350. epoch_num = epoch
  351. else:
  352. epoch_num = math.ceil(epoch * sink_size / train_dataset.get_dataset_size())
  353. dataset_helper, train_network = self._exec_preprocess(self._train_network,
  354. is_train=True,
  355. phase='train',
  356. dataset=train_dataset,
  357. dataset_sink_mode=True,
  358. sink_size=sink_size,
  359. epoch_num=epoch_num)
  360. self._train_network = train_network
  361. cb_params.train_network = self._train_network
  362. cb_params.cur_step_num = 0
  363. run_context = RunContext(cb_params)
  364. list_callback.begin(run_context)
  365. # used to stop training for early stop, such as stopAtTIme or stopATStep
  366. should_stop = False
  367. for i in range(epoch):
  368. cb_params.cur_epoch_num = i + 1
  369. list_callback.epoch_begin(run_context)
  370. # for data sink dataset_helper only iter once, other wise iter epoch_size times.
  371. for inputs in dataset_helper:
  372. cb_params.train_dataset_element = inputs
  373. list_callback.step_begin(run_context)
  374. outputs = self._train_network(*inputs)
  375. cb_params.cur_step_num += dataset_helper.sink_size()
  376. cb_params.net_outputs = outputs
  377. list_callback.step_end(run_context)
  378. dataset_helper.continue_send()
  379. list_callback.epoch_end(run_context)
  380. should_stop = should_stop or run_context.get_stop_requested()
  381. if should_stop:
  382. break
  383. dataset_helper.stop_send()
  384. list_callback.end(run_context)
  385. def _train_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
  386. """
  387. Training process. The data would be passed to network directly.
  388. Args:
  389. epoch (int): Total number of iterations on the data.
  390. train_dataset (Dataset): A training dataset iterator. If there is no
  391. loss_fn, a tuple with multiple data (data1, data2, data3, ...) should be
  392. returned and passed to the network. Otherwise, a tuple (data, label) should
  393. be returned. The data and label would be passed to the network and loss
  394. function respectively.
  395. list_callback (Callback): Executor of callback list. Default: None.
  396. cb_params (_InternalCallbackParam): Callback parameters. Default: None.
  397. """
  398. dataset_helper, _ = self._exec_preprocess(self._train_network,
  399. is_train=True,
  400. phase='train',
  401. dataset=train_dataset,
  402. dataset_sink_mode=False,
  403. epoch_num=epoch)
  404. cb_params.cur_step_num = 0
  405. run_context = RunContext(cb_params)
  406. list_callback.begin(run_context)
  407. # used to stop training for early stop, such as stopAtTIme or stopATStep
  408. should_stop = False
  409. for i in range(epoch):
  410. cb_params.cur_epoch_num = i + 1
  411. list_callback.epoch_begin(run_context)
  412. for next_element in dataset_helper:
  413. len_element = len(next_element)
  414. next_element = _transfer_tensor_to_tuple(next_element)
  415. if self._loss_fn and len_element != 2:
  416. raise ValueError("when loss_fn is not None, train_dataset should"
  417. "return two elements, but got {}".format(len_element))
  418. cb_params.cur_step_num += 1
  419. cb_params.train_dataset_element = next_element
  420. list_callback.step_begin(run_context)
  421. outputs = self._train_network(*next_element)
  422. cb_params.net_outputs = outputs
  423. if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
  424. _, overflow, _ = outputs
  425. overflow = np.all(overflow.asnumpy())
  426. self._loss_scale_manager.update_loss_scale(overflow)
  427. list_callback.step_end(run_context)
  428. if _is_role_pserver():
  429. os._exit(0)
  430. should_stop = should_stop or run_context.get_stop_requested()
  431. if should_stop:
  432. break
  433. train_dataset.reset()
  434. list_callback.epoch_end(run_context)
  435. should_stop = should_stop or run_context.get_stop_requested()
  436. if should_stop:
  437. break
  438. list_callback.end(run_context)
  439. def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1):
  440. """
  441. Training API where the iteration is controlled by python front-end.
  442. When setting pynative mode or CPU, the training process will be performed with dataset not sink.
  443. Note:
  444. If dataset_sink_mode is True, epoch of training should be equal to the count of repeat
  445. operation in dataset processing. Otherwise, errors could occur since the amount of data
  446. is not equal to the required amount of training .
  447. If dataset_sink_mode is True, data will be sent to device. If device is Ascend, features
  448. of data will be transferred one by one. The limitation of data transmission per time is 256M.
  449. If sink_size > 0, each epoch the dataset can be traversed unlimited times until you get sink_size
  450. elements of the dataset. Next epoch continues to traverse from the end position of the previous traversal.
  451. Args:
  452. epoch (int): Generally, total number of iterations on the data per epoch.
  453. When dataset_sink_mode is set to true and sink_size>0, each epoch sink sink_size
  454. steps on the data instead of total number of iterations.
  455. train_dataset (Dataset): A training dataset iterator. If there is no
  456. loss_fn, a tuple with multiple data (data1, data2, data3, ...) should be
  457. returned and passed to the network. Otherwise, a tuple (data, label) should
  458. be returned. The data and label would be passed to the network and loss
  459. function respectively.
  460. callbacks (list): List of callback objects which should be executed while training. Default: None.
  461. dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.
  462. Configure pynative mode or CPU, the training process will be performed with
  463. dataset not sink.
  464. sink_size (int): Control the amount of data in each sink.
  465. If sink_size = -1, sink the complete dataset for each epoch.
  466. If sink_size > 0, sink sink_size data for each epoch.
  467. If dataset_sink_mode is False, set sink_size as invalid. Default: -1.
  468. Examples:
  469. >>> dataset = get_dataset()
  470. >>> net = Net()
  471. >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
  472. >>> loss_scale_manager = FixedLossScaleManager()
  473. >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
  474. >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
  475. >>> model.train(2, dataset)
  476. """
  477. dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
  478. if sink_size == -1:
  479. sink_size = train_dataset.get_dataset_size()
  480. check_int(sink_size)
  481. if sink_size < -1 or sink_size == 0:
  482. raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))
  483. _device_number_check(self._parallel_mode, self._device_number)
  484. _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
  485. self._train(epoch,
  486. train_dataset,
  487. callbacks=callbacks,
  488. dataset_sink_mode=dataset_sink_mode,
  489. sink_size=sink_size)
  490. def _eval_dataset_sink_process(self, valid_dataset, list_callback=None, cb_params=None):
  491. """
  492. Evaluation. The data would be passed to network through dataset channel.
  493. Args:
  494. valid_dataset (Dataset): Dataset to evaluate the model.
  495. list_callback (Callback): Executor of callback list. Default: None.
  496. cb_params (_InternalCallbackParam): Callback parameters. Default: None.
  497. Returns:
  498. Dict, which returns the loss value and metrics values for the model in the test mode.
  499. """
  500. run_context = RunContext(cb_params)
  501. dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
  502. is_train=False,
  503. phase='eval',
  504. dataset=valid_dataset,
  505. dataset_sink_mode=True)
  506. self._eval_network = eval_network
  507. cb_params.eval_network = self._eval_network
  508. list_callback.begin(run_context)
  509. for inputs in dataset_helper:
  510. cb_params.cur_step_num += 1
  511. list_callback.step_begin(run_context)
  512. outputs = self._eval_network(*inputs)
  513. cb_params.net_outputs = outputs
  514. list_callback.step_end(run_context)
  515. self._update_metrics(outputs)
  516. metrics = self._get_metrics()
  517. cb_params.metrics = metrics
  518. list_callback.end(run_context)
  519. return metrics
  520. def _eval_process(self, valid_dataset, list_callback=None, cb_params=None):
  521. """
  522. Evaluation. The data would be passed to network directly.
  523. Args:
  524. valid_dataset (Dataset): Dataset to evaluate the model.
  525. list_callback (Callback): Executor of callback list. Default: None.
  526. cb_params (_InternalCallbackParam): Callback parameters. Default: None.
  527. Returns:
  528. Dict, which returns the loss value and metrics values for the model in the test mode.
  529. """
  530. run_context = RunContext(cb_params)
  531. list_callback.begin(run_context)
  532. dataset_helper, _ = self._exec_preprocess(self._eval_network,
  533. is_train=False,
  534. phase='eval',
  535. dataset=valid_dataset,
  536. dataset_sink_mode=False)
  537. for next_element in dataset_helper:
  538. cb_params.cur_step_num += 1
  539. list_callback.step_begin(run_context)
  540. next_element = _transfer_tensor_to_tuple(next_element)
  541. outputs = self._eval_network(*next_element)
  542. cb_params.net_outputs = outputs
  543. list_callback.step_end(run_context)
  544. self._update_metrics(outputs)
  545. valid_dataset.reset()
  546. metrics = self._get_metrics()
  547. cb_params.metrics = metrics
  548. list_callback.end(run_context)
  549. return metrics
  550. def eval(self, valid_dataset, callbacks=None, dataset_sink_mode=True):
  551. """
  552. Evaluation API where the iteration is controlled by python front-end.
  553. Configure to pynative mode or CPU, the evaluating process will be performed with dataset non-sink mode.
  554. Note:
  555. If dataset_sink_mode is True, data will be sent to device. If device is Ascend, features
  556. of data will be transferred one by one. The limitation of data transmission per time is 256M.
  557. Args:
  558. valid_dataset (Dataset): Dataset to evaluate the model.
  559. callbacks (list): List of callback objects which should be executed while training. Default: None.
  560. dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.
  561. Returns:
  562. Dict, which returns the loss value and metrics values for the model in the test mode.
  563. Examples:
  564. >>> dataset = get_dataset()
  565. >>> net = Net()
  566. >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
  567. >>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
  568. >>> model.eval(dataset)
  569. """
  570. dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
  571. _device_number_check(self._parallel_mode, self._device_number)
  572. if not self._metric_fns:
  573. raise ValueError("metric fn can not be None or empty.")
  574. cb_params = _InternalCallbackParam()
  575. cb_params.eval_network = self._eval_network
  576. cb_params.valid_dataset = valid_dataset
  577. cb_params.batch_num = valid_dataset.get_dataset_size()
  578. cb_params.mode = "eval"
  579. cb_params.cur_step_num = 0
  580. cb_params.list_callback = self._transform_callbacks(callbacks)
  581. cb_params.network = self._network
  582. self._clear_metrics()
  583. if context.get_context("device_target") == "CPU":
  584. dataset_sink_mode = False
  585. logger.warning("CPU cannot support dataset sink mode currently."
  586. "So the evaluating process will be performed with dataset non-sink mode.")
  587. with _CallbackManager(callbacks) as list_callback:
  588. if dataset_sink_mode:
  589. return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
  590. return self._eval_process(valid_dataset, list_callback, cb_params)
  591. def predict(self, *predict_data):
  592. """
  593. Generate output predictions for the input samples.
  594. Data could be a single tensor, a list of tensor, or a tuple of tensor.
  595. Note:
  596. Batch data should be put together in one tensor.
  597. Args:
  598. predict_data (Tensor): Tensor of predict data. can be array, list or tuple.
  599. Returns:
  600. Tensor, array(s) of predictions.
  601. Examples:
  602. >>> input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
  603. >>> model = Model(Net())
  604. >>> model.predict(input_data)
  605. """
  606. self._predict_network.set_train(False)
  607. check_input_data(*predict_data, data_class=Tensor)
  608. result = self._predict_network(*predict_data)
  609. check_output_data(result)
  610. return result
  611. __all__ = ["Model"]