Browse Source

!612 Bugfix: correct wrong pattern of confusion_softmax_grad_rule pass

Merge pull request !612 from huanghui/fix-confusion-softmax-grad-rule-pass
tags/v0.2.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
5034bc10ce
2 changed files with 2 additions and 2 deletions
  1. +1
    -1
      mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.cc
  2. +1
    -1
      tests/ut/cpp/python_input/gtest_input/pre_activate/confusion_softmax_grad_rule.py

+ 1
- 1
mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.cc View File

@@ -47,7 +47,7 @@ void SetAttrsForFusionNode(const AnfNodePtr &sub_anf, const AnfNodePtr &fusion_n

const BaseRef ConfusionSoftmaxGradRule::DefinePattern() const {
return VectorRef(
{prim::kPrimSub, input0_, VectorRef({prim::kPrimReduceSum, VectorRef({prim::kPrimMul, input0_, input1_})})});
{prim::kPrimSub, input0_, VectorRef({prim::kPrimReduceSum, VectorRef({prim::kPrimMul, input1_, input0_})})});
}

const AnfNodePtr ConfusionSoftmaxGradRule::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,


+ 1
- 1
tests/ut/cpp/python_input/gtest_input/pre_activate/confusion_softmax_grad_rule.py View File

@@ -41,7 +41,7 @@ def test_confusion_softmax_grad_rule(tag):

@fns
def before(input0, input1):
res = mul(input0, input1)
res = mul(input1, input0)
# input axis will be convert to attr in ConstructKernelGraph step
res = reduce_sum(res, axis)
res = sub(input0, res)


Loading…
Cancel
Save