You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.nn.thor.rst 4.1 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. mindspore.nn.thor
  2. ==================
  3. .. py:class:: mindspore.nn.thor(net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0, batch_size=32, use_nesterov=False, decay_filter=<function <lambda> at 0x0000029724CFA048>, split_indices=None, enable_clip_grad=False, frequency=100)
  4. 通过二阶算法THOR更新参数。
  5. 基于跟踪的、硬件驱动层定向的自然梯度下降计算(THOR)算法论文地址为:
  6. `THOR: Trace-based Hardware-driven layer-ORiented Natural Gradient Descent Computation <https://www.aaai.org/AAAI21Papers/AAAI-6611.ChenM.pdf>`_
  7. 更新公式如下:
  8. .. math::
  9. \begin{array}{ll}
  10. & \textbf{Parameter:} \: \text{the learning rate } \gamma\text{, the damping parameter }\lambda \\
  11. & \textbf{Init:} \: \lambda \leftarrow 0 \\
  12. & A_{i-1}=\mathbb{E}\left[a_{i-1} a_{i-1}^{T}\right] \\
  13. & G_{i}=\mathbb{E}\left[D_{s_i} D_{s_i}^{T}\right] \\
  14. & w_{i}^{(k+1)} \leftarrow w_{i}^{(k)}-\gamma\left(\left(A_{i-1}^{(k)}+\lambda I\right)^{-1}
  15. \otimes\left(G_{i}^{(k)}+\lambda I\right)^{-1}\right) \nabla_{w_{i}} J^{(k)}
  16. \end{array}
  17. :math:`a_{i-1}` 表示第i层的输入,它是上一层的激活。
  18. :math:`D_{s_i}` 表示第i层输出的loss函数的导数。
  19. :math:`I` 代表单位矩阵。
  20. :math:`\lambda` 表示 :math:`damping` 参数, :math:`g_i` 表示第i层的梯度。
  21. :math:`\otimes` 表示克罗内克尔积, :math:`\gamma` 表示学习率。
  22. .. note::
  23. 在分离参数组时,每个组的 `weight_decay` 将应用于对应参数。当不分离参数组时,优化器中的 `weight_decay` 将应用于名称中没有'beta'或 'gamma'的参数。
  24. 在分离参数组时,如果要集中梯度,请将grad_centralization设置为True,但集中梯度只能应用于卷积层的参数。
  25. 如果非卷积层的参数设置为True,则会报错。
  26. 为了提高参数组的性能,可以支持自定义参数的顺序。
  27. **参数:**
  28. - **net** (Cell) - 训练网络。
  29. - **learning_rate** (Tensor) - 学习率的值。
  30. - **damping** (Tensor) - 阻尼值。
  31. - **momentum** (float) - float类型的超参数,表示移动平均的动量。至少为0.0。
  32. - **weight_decay** (int, float) - 权重衰减(L2 penalty)。必须等于或大于0.0。默认值:0.0。
  33. - **loss_scale** (float) - loss损失缩放系数。必须大于0.0。一般情况下,使用默认值。默认值:1.0。
  34. - **batch_size** (int) - batch的大小。默认值:32。
  35. - **use_nesterov** (bool) - 启用Nesterov动量。默认值:False。
  36. - **decay_filter** (function) - 用于确定权重衰减应用于哪些层的函数,只有在weight_decay>0时才有效。默认值:lambda x: x.name not in []。
  37. - **split_indices** (list) - 按A/G层(A/G含义见上述公式)索引设置allreduce融合策略。仅在分布式计算中有效。ResNet50作为一个样本,A/G的层数分别为54层,当split_indices设置为[26,53]时,表示A/G被分成两组allreduce,一组为0~26层,另一组是27~53层。默认值:None。
  38. - **enable_clip_grad** (bool) - 是否剪切梯度。默认值:False。
  39. - **frequency** (int) - A/G和$A^{-1}/G^{-1}$的更新间隔。每隔frequency个step,A/G和$A^{-1}/G^{-1}$将更新一次。必须大于1。默认值:100。
  40. **输入:**
  41. - **gradients** (tuple[Tensor]) - 训练参数的梯度,矩阵维度与训练参数相同。
  42. **输出:**
  43. tuple[bool],所有元素都为True。
  44. **异常:**
  45. - **TypeError** - `learning_rate` 不是张量。
  46. - **TypeError** - `loss_scale` 、 `momentum` 或 `frequency` 不是浮点数。
  47. - **TypeError** - `weight_decay` 既不是浮点数也不是整数。
  48. - **TypeError** - `use_nesterov` 不是布尔值。
  49. - **TypeError** - `frequency` 不是整数。
  50. - **ValueError** - `loss_scale` 小于或等于0。
  51. - **ValueError** - `weight_decay` 或 `momentum` 小于0。
  52. - **ValueError** - `frequency` 小于2。