Browse Source

!1558 fix performance of bert

Merge pull request !1558 from chenhaozhe/bert-optimiaztion
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
a7744bde25
2 changed files with 3 additions and 3 deletions
  1. +1
    -1
      mindspore/nn/optim/lamb.py
  2. +2
    -2
      mindspore/ops/_grad/grad_math_ops.py

+ 1
- 1
mindspore/nn/optim/lamb.py View File

@@ -180,7 +180,7 @@ class Lamb(Optimizer):
beta2=0.999,
eps=1e-6,
weight_decay=0.0,
decay_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name):
decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()):


super(Lamb, self).__init__(start_learning_rate, params)
if self.is_group:


+ 2
- 2
mindspore/ops/_grad/grad_math_ops.py View File

@@ -194,8 +194,8 @@ def get_bprop_mul(self):
mul_func = P.Mul()


def bprop(x, y, out, dout):
bc_dx = mul_func(dout, y)
bc_dy = mul_func(dout, x)
bc_dx = mul_func(y, dout)
bc_dy = mul_func(x, dout)
return binop_grad_common(x, y, bc_dx, bc_dy)


return bprop


Loading…
Cancel
Save