You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.nn.LARS.rst 2.2 kB

4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. mindspore.nn.LARS
  2. ==================
  3. .. py:class:: mindspore.nn.LARS(*args, **kwargs)
  4. 使用LARSUpdate算子实现LARS算法。
  5. LARS算法采用大量的优化技术。详见论文 `LARGE BATCH TRAINING OF CONVOLUTIONAL NETWORKS <https://arxiv.org/abs/1708.03888>`_。
  6. 更新公式如下:
  7. .. math::
  8. \begin{array}{ll} \\
  9. \lambda = \frac{\theta \text{ * } || \omega || } \\
  10. {|| g_{t} || \text{ + } \delta \text{ * } || \omega || } \\
  11. \lambda =
  12. \begin{cases}
  13. \min(\frac{\lambda}{\alpha }, 1)
  14. & \text{ if } clip = True \\
  15. \lambda
  16. & \text{ otherwise }
  17. \end{cases}\\
  18. g_{t+1} = \lambda * (g_{t} + \delta * \omega)
  19. \end{array}
  20. :math:`\theta` 表示 `coefficient` ,:math:`\omega` 表示网络参数,:math:`g` 表示 `gradients`,:math:`t` 表示当前step,:math:`\delta` 表示 `optimizer` 配置的 `weight_decay` ,:math:`\alpha` 表示 `optimizer` 配置的 `learning_rate` ,:math:`clip` 表示 `use_clip`。
  21. **参数:**
  22. - **optimizer** (Optimizer) - 待封装和修改梯度的MindSpore优化器。
  23. - **epsilon** (float) - 将添加到分母中,提高数值稳定性。默认值:1e-05。
  24. - **coefficient** (float) - 计算局部学习速率的信任系数。默认值:0.001。
  25. - **use_clip** (bool) - 计算局部学习速率时是否裁剪。默认值:False。
  26. - **lars_filter** (Function) - 用于指定使用LARS算法的网络参数。默认值:lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name。
  27. **输入:**
  28. - **gradients** (tuple[Tensor]) - 优化器中 `params` 的梯度,shape与优化器中的 `params` 相同。
  29. **输出:**
  30. Union[Tensor[bool], tuple[Parameter]],取决于 `optimizer` 的输出。
  31. **支持平台:**
  32. ``Ascend`` ``CPU``
  33. **样例:**
  34. >>> net = Net()
  35. >>> loss = nn.SoftmaxCrossEntropyWithLogits()
  36. >>> opt = nn.Momentum(net.trainable_params(), 0.1, 0.9)
  37. >>> opt_lars = nn.LARS(opt, epsilon=1e-08, coefficient=0.02)
  38. >>> model = Model(net, loss_fn=loss, optimizer=opt_lars, metrics=None)