@@ -687,7 +687,7 @@ bool IsSameNode(const EquivPtr &equiv1, const EquivPtr &equiv2, const VarPtr &va
   MS_EXCEPTION_IF_NULL(equiv1_node);
   auto equiv2_node = GetAnfNodeByVar(equiv2, var_node);
   MS_EXCEPTION_IF_NULL(equiv2_node);
-  return equiv1_node == equiv2_node;
+  return *equiv1_node == *equiv2_node;
 }

 AnfNodePtr GetAnfNodeByVar(const EquivPtr &equiv, const VarPtr &var_node) {
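Note on the hunk above: `equiv1_node` and `equiv2_node` are smart pointers, so the old `==` compared pointer identity, and two structurally identical nodes held by different pointers would compare unequal. Dereferencing (`*equiv1_node == *equiv2_node`) invokes the node type's own equality instead. A minimal sketch of the distinction in plain Python (the `Node` class is hypothetical, not MindSpore API); `is` plays the role of the pointer comparison and `==` the role of the dereferenced comparison:

class Node:
    """Hypothetical stand-in for an ANF node with value equality."""
    def __init__(self, op, inputs):
        self.op = op
        self.inputs = tuple(inputs)

    def __eq__(self, other):
        # Value equality: same operator and same inputs.
        return isinstance(other, Node) and (self.op, self.inputs) == (other.op, other.inputs)

a = Node('Add', ['x', 'y'])
b = Node('Add', ['x', 'y'])
print(a is b)  # False -- distinct objects, like comparing the raw pointers
print(a == b)  # True  -- equal values, like comparing the dereferenced nodes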
@@ -180,7 +180,7 @@ class Lamb(Optimizer):
                  beta2=0.999,
                  eps=1e-6,
                  weight_decay=0.0,
-                 decay_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name):
+                 decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()):
         super(Lamb, self).__init__(start_learning_rate, params)
         if self.is_group:
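Note on the hunk above: the old default `decay_filter` matched 'LayerNorm' case-sensitively and 'bias' only in lowercase, so a parameter named, say, 'dense.Bias' would still receive weight decay. Lower-casing the name first makes the exclusion case-insensitive. A minimal sketch of the new filter's behavior (the parameter names below are made up for illustration):

class Param:
    def __init__(self, name):
        self.name = name

decay_filter = lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()

for name in ['bert.encoder.dense.weight', 'bert.LayerNorm.gamma', 'dense.Bias']:
    print(name, decay_filter(Param(name)))
# bert.encoder.dense.weight True   -> weight decay applied
# bert.LayerNorm.gamma      False  -> excluded
# dense.Bias                False  -> excluded (the old filter missed capitalized 'Bias')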
@@ -191,8 +191,8 @@ def get_bprop_mul(self):
     mul_func = P.Mul()

     def bprop(x, y, out, dout):
-        bc_dx = mul_func(dout, y)
-        bc_dy = mul_func(dout, x)
+        bc_dx = mul_func(y, dout)
+        bc_dy = mul_func(x, dout)
         return binop_grad_common(x, y, bc_dx, bc_dy)

     return bprop
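Note on the hunk above: for out = x * y the product rule gives dL/dx = dout * y and dL/dy = dout * x, and since elementwise multiplication commutes, swapping the operands passed to `mul_func` leaves the result mathematically unchanged; the change only normalizes the argument order before `binop_grad_common` (which, going by its name, reduces the gradients over broadcast dimensions). A quick numeric check, with NumPy standing in for the MindSpore ops:

import numpy as np

x = np.array([1.0, 2.0, 3.0])
y = np.array([4.0, 5.0, 6.0])
dout = np.ones_like(x)   # upstream gradient of f(x, y) = sum(x * y)

bc_dx = y * dout         # matches mul_func(y, dout)
bc_dy = x * dout         # matches mul_func(x, dout)

# Finite-difference check of df/dx along the all-ones direction:
eps = 1e-6
num_dx = (np.sum((x + eps) * y) - np.sum((x - eps) * y)) / (2 * eps)
print(bc_dx.sum(), num_dx)   # both ~15.0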