You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.nn.AdaSumByGradWrapCell.rst 1.9 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637
  1. mindspore.nn.AdaSumByGradWrapCell
  2. =================================
  3. .. py:class:: mindspore.nn.AdaSumByGradWrapCell(optimizer)
  4. Adaptive Summation (AdaSum)算法的实现,根据梯度计算。应用于semi_auto_parallel/auto_parallel模式。
  5. 请参阅论文 `AdaSum: Scaling Distributed Training with Adaptive Summation <https://arxiv.org/abs/2006.02924>`_。
  6. 公式如下:
  7. .. math::
  8. \begin{array}{ll}
  9. w_{t+1}=w_{t} - \alpha \cdot Adasum(g_{1}, g_{2}) \\
  10. w_{t+1}=w_{t} - \alpha \cdot [(1 - \frac{g_2^{T}\cdot g_1}{2\cdot \left \| g_1 \right \|^2 })\cdot g_1 + (1 - \frac{g_1^{T}\cdot g_2}{2\cdot \left \| g_2 \right \|^2 })\cdot g_2] \\
  11. \end{array}
  12. 在本实现中, :math:`g` 代表权重的梯度,下标代表数据并行维度下不同的设备。
  13. .. note::
  14. 本接口推荐应用于半自动并行或者全自动并行模式。针对数据并行模式,推荐使用mindspore.boost功能以使用AdaSum。
  15. 使用本接口时,训练的卡的数量必须是2的幂,并且至少需要16张卡。目前,使用本接口时不支持优化器并行和流水线并行。
  16. **参数:**
  17. - **optimizer** (nn.optimizer) - 必须是单输入的优化器。
  18. **输入:**
  19. - **grads** (tuple[Tensor]) - `params` 的梯度,形状(shape)与 `params` 相同,与所传优化器的输入一致。
  20. **异常:**
  21. - **RuntimeError** - `parallel_mode` 使用了 `stand_alone` 模式, AdaSum仅支持在分布式场景下使用。
  22. - **RuntimeError** - 同时使用了优化器并行, 暂时不支持在优化器并行场景下使用AdaSum。
  23. - **RuntimeError** - 同时使用了流水线并行, 暂时不支持在流水线并行场景下使用AdaSum。
  24. - **RuntimeError** - `device_num` 不是2的幂,或者小于16。