Browse Source

bugfix: confusion_softmax_grad need to be set with axis and keep_dims attr

tags/v0.2.0-alpha
huanghui 5 years ago
parent
commit
7cded1ec32
1 changed files with 22 additions and 0 deletions
  1. +22
    -0
      mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.cc

+ 22
- 0
mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.cc View File

@@ -21,9 +21,30 @@
#include "session/anf_runtime_algorithm.h" #include "session/anf_runtime_algorithm.h"
#include "ir/primitive.h" #include "ir/primitive.h"
#include "utils/utils.h" #include "utils/utils.h"
#include "pre_activate/common/helper.h"


namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
namespace {
void SetAttrsForFusionNode(const AnfNodePtr &sub_anf, const AnfNodePtr &fusion_node) {
MS_EXCEPTION_IF_NULL(sub_anf);
MS_EXCEPTION_IF_NULL(fusion_node);
auto sub = sub_anf->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(sub);
if (sub->size() != kSubInputNum) {
MS_LOG(EXCEPTION) << "Sub's size is not equal with 3";
}
auto reduce_sum_anf = sub->input(2);
MS_EXCEPTION_IF_NULL(reduce_sum_anf);
auto reduce_sum = reduce_sum_anf->cast<CNodePtr>();
if (reduce_sum == nullptr) {
MS_LOG(EXCEPTION) << "Sub's second input is not a cnode";
}
AnfAlgo::CopyNodeAttr(kAttrAxis, reduce_sum, fusion_node);
AnfAlgo::CopyNodeAttr(kAttrKeepDims, reduce_sum, fusion_node);
}
} // namespace

const BaseRef ConfusionSoftmaxGradRule::DefinePattern() const { const BaseRef ConfusionSoftmaxGradRule::DefinePattern() const {
return VectorRef( return VectorRef(
{prim::kPrimSub, input0_, VectorRef({prim::kPrimReduceSum, VectorRef({prim::kPrimMul, input0_, input1_})})}); {prim::kPrimSub, input0_, VectorRef({prim::kPrimReduceSum, VectorRef({prim::kPrimMul, input0_, input1_})})});
@@ -48,6 +69,7 @@ const AnfNodePtr ConfusionSoftmaxGradRule::Process(const FuncGraphPtr &graph, co
auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)}; auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, confusion_softmax_grad.get()); AnfAlgo::SetOutputInferTypeAndShape(types, shapes, confusion_softmax_grad.get());
confusion_softmax_grad->set_scope(node->scope()); confusion_softmax_grad->set_scope(node->scope());
SetAttrsForFusionNode(node, confusion_softmax_grad);
return confusion_softmax_grad; return confusion_softmax_grad;
} }
} // namespace opt } // namespace opt


Loading…
Cancel
Save