You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.nn.Optimizer.rst 5.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. mindspore.nn.Optimizer
  2. ======================
  3. .. py:class:: mindspore.nn.Optimizer(learning_rate, parameters, weight_decay=0.0, loss_scale=1.0)
  4. 用于参数更新的优化器基类。不要直接使用这个类,请实例化它的一个子类。
  5. 优化器支持参数分组。当参数分组时,每组参数均可配置不同的学习率(`lr` )、权重衰减(`weight_decay`)和梯度中心化(`grad_centralization`)策略。
  6. .. note::
  7. .. include:: mindspore.nn.optim_note_weight_decay.rst
  8. **参数:**
  9. - **learning_rate** (Union[float, int, Tensor, Iterable, LearningRateSchedule]):
  10. .. include:: mindspore.nn.optim_arg_dynamic_lr.rst
  11. - **parameters (Union[list[Parameter], list[dict]])** - 必须是 `Parameter` 组成的列表或字典组成的列表。当列表元素是字典时,字典的键可以是"params"、"lr"、"weight_decay"、"grad_centralization"和"order_params":
  12. .. include:: mindspore.nn.optim_group_param.rst
  13. .. include:: mindspore.nn.optim_group_lr.rst
  14. .. include:: mindspore.nn.optim_group_weight_decay.rst
  15. .. include:: mindspore.nn.optim_group_gc.rst
  16. .. include:: mindspore.nn.optim_group_order.rst
  17. - **weight_decay** (Union[float, int]) - 权重衰减的整数或浮点值。必须等于或大于0。如果 `weight_decay` 是整数,它将被转换为浮点数。默认值:0.0。
  18. .. include:: mindspore.nn.optim_arg_loss_scale.rst
  19. **异常:**
  20. - **TypeError** - `learning_rate` 不是int、float、Tensor、Iterable或LearningRateSchedule。
  21. - **TypeError** - `parameters` 的元素不是Parameter或字典。
  22. - **TypeError** - `loss_scale` 不是float。
  23. - **TypeError** - `weight_decay` 不是float或int。
  24. - **ValueError** - `loss_scale` 小于或等于0。
  25. - **ValueError** - `weight_decay` 小于0。
  26. - **ValueError** - `learning_rate` 是一个Tensor,但是Tensor的维度大于1。
  27. **支持平台:**
  28. ``Ascend`` ``GPU`` ``CPU``
  29. .. py:method:: broadcast_params(optim_result)
  30. 按参数组的顺序进行参数广播。
  31. **参数:**
  32. - **optim_result** (bool) - 参数更新结果。该输入用来保证参数更新完成后才执行参数广播。
  33. **返回:**
  34. bool,状态标志。
  35. .. py:method:: decay_weight(gradients)
  36. 衰减权重。
  37. 一种减少深度学习神经网络模型过拟合的方法。继承 :class:`mindspore.nn.Optimizer` 自定义优化器时,可调用该接口进行权重衰减。
  38. **参数:**
  39. - **gradients** (tuple[Tensor]) - 网络参数的梯度,形状(shape)与网络参数相同。
  40. **返回:**
  41. tuple[Tensor],衰减权重后的梯度。
  42. .. py:method:: get_lr()
  43. 优化器调用该接口获取当前步骤(step)的学习率。继承 :class:`mindspore.nn.Optimizer` 自定义优化器时,可在参数更新前调用该接口获取学习率。
  44. **返回:**
  45. float,当前步骤的学习率。
  46. .. py:method:: get_lr_parameter(param)
  47. 用于在使用网络参数分组功能,且为不同组别配置不同的学习率时,获取指定参数的学习率。
  48. **参数:**
  49. - **param** (Union[Parameter, list[Parameter]]) - `Parameter` 或 `Parameter` 列表。
  50. **返回:**
  51. Parameter,单个 `Parameter` 或 `Parameter` 列表。如果使用了动态学习率,返回用于计算学习率的 `LearningRateSchedule` 或 `LearningRateSchedule` 列表。
  52. **样例:**
  53. >>> from mindspore import nn
  54. >>> net = Net()
  55. >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
  56. >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
  57. >>> group_params = [{'params': conv_params, 'lr': 0.05},
  58. ... {'params': no_conv_params, 'lr': 0.01}]
  59. >>> optim = nn.Momentum(group_params, learning_rate=0.1, momentum=0.9, weight_decay=0.0)
  60. >>> conv_lr = optim.get_lr_parameter(conv_params)
  61. >>> print(conv_lr[0].asnumpy())
  62. 0.05
  63. .. py:method:: gradients_centralization(gradients)
  64. 梯度中心化。
  65. 一种优化卷积层参数以提高深度学习神经网络模型训练速度的方法。继承 :class:`mindspore.nn.Optimizer` 自定义优化器时,可调用该接口进行梯度中心化。
  66. **参数:**
  67. - **gradients** (tuple[Tensor]) - 网络参数的梯度,形状(shape)与网络参数相同。
  68. **返回:**
  69. tuple[Tensor],梯度中心化后的梯度。
  70. .. py:method:: scale_grad(gradients)
  71. 用于在混合精度场景还原梯度。
  72. 继承 :class:`mindspore.nn.Optimizer` 自定义优化器时,可调用该接口还原梯度。
  73. **参数:**
  74. - **gradients** (tuple[Tensor]) - 网络参数的梯度,形状(shape)与网络参数相同。
  75. **返回:**
  76. tuple[Tensor],还原后的梯度。
  77. .. include:: mindspore.nn.optim_target_unique_for_sparse.rst