|
|
|
@@ -402,6 +402,18 @@ def _parallel_check(): |
|
|
|
class AdaSumByGradWrapCell(Cell): |
|
|
|
r""" |
|
|
|
Enable AdaSum in "auto_parallel/semi_auto_parallel" mode. |
|
|
|
This implementation of the Adaptive Summation (AdaSum) algorithm is computed from the gradients. |
|
|
|
See the paper `AdaSum: Scaling Distributed Training with Adaptive Summation <https://arxiv.org/abs/2006.02924>`_. |
|
|
|
|
|
|
|
.. math:: |
|
|
|
\begin{array}{ll} |
|
|
|
w_{t+1}=w_{t} - \alpha \cdot \mathrm{AdaSum}(g_{1}, g_{2}) \\ |
|
|
|
w_{t+1}=w_{t} - \alpha \cdot [(1 - \frac{g_2^{T}\cdot g_1}{2\cdot \left \| g_1 \right \|^2 })\cdot g_1 + |
|
|
|
(1 - \frac{g_1^{T}\cdot g_2}{2\cdot \left \| g_2 \right \|^2 })\cdot g_2] \\ |
|
|
|
\end{array} |
|
|
|
|
|
|
|
In this implementation, :math:`g` represents the gradient of the weights, |
|
|
|
and the subscripts represent different devices in the data-parallel dimension. |
|
|
|
|
|
|
|
Note: |
|
|
|
When using AdaSum, the number of training cards needs to be a power of 2 and at least 16 cards are required. |
|
|
|
@@ -456,6 +468,19 @@ class AdaSumByGradWrapCell(Cell): |
|
|
|
class AdaSumByDeltaWeightWrapCell(Cell): |
|
|
|
r""" |
|
|
|
Enable AdaSum in "auto_parallel/semi_auto_parallel" mode. |
|
|
|
This implementation of the Adaptive Summation (AdaSum) algorithm is computed from the difference between the weights |
|
|
|
before and after the optimizer update. |
|
|
|
See the paper `AdaSum: Scaling Distributed Training with Adaptive Summation <https://arxiv.org/abs/2006.02924>`_. |
|
|
|
|
|
|
|
.. math:: |
|
|
|
\begin{array}{ll} |
|
|
|
w_{t+1}=w_{t} - \alpha \cdot \mathrm{AdaSum}(g_{1}, g_{2}) \\ |
|
|
|
w_{t+1}=w_{t} - \alpha \cdot [(1 - \frac{g_2^{T}\cdot g_1}{2\cdot \left \| g_1 \right \|^2 })\cdot g_1 + |
|
|
|
(1 - \frac{g_1^{T}\cdot g_2}{2\cdot \left \| g_2 \right \|^2 })\cdot g_2] \\ |
|
|
|
\end{array} |
|
|
|
|
|
|
|
In this implementation, :math:`g` represents the weight difference before and after the optimizer update, |
|
|
|
and the subscripts represent different devices in the data-parallel dimension. |
|
|
|
|
|
|
|
Note: |
|
|
|
When using AdaSum, the number of training cards needs to be a power of 2 and at least 16 cards are required. |
|
|
|
|