| @@ -21,9 +21,30 @@ | |||||
| #include "session/anf_runtime_algorithm.h" | #include "session/anf_runtime_algorithm.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| #include "utils/utils.h" | #include "utils/utils.h" | ||||
| #include "pre_activate/common/helper.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| namespace { | |||||
| void SetAttrsForFusionNode(const AnfNodePtr &sub_anf, const AnfNodePtr &fusion_node) { | |||||
| MS_EXCEPTION_IF_NULL(sub_anf); | |||||
| MS_EXCEPTION_IF_NULL(fusion_node); | |||||
| auto sub = sub_anf->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(sub); | |||||
| if (sub->size() != kSubInputNum) { | |||||
| MS_LOG(EXCEPTION) << "Sub's size is not equal with 3"; | |||||
| } | |||||
| auto reduce_sum_anf = sub->input(2); | |||||
| MS_EXCEPTION_IF_NULL(reduce_sum_anf); | |||||
| auto reduce_sum = reduce_sum_anf->cast<CNodePtr>(); | |||||
| if (reduce_sum == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "Sub's second input is not a cnode"; | |||||
| } | |||||
| AnfAlgo::CopyNodeAttr(kAttrAxis, reduce_sum, fusion_node); | |||||
| AnfAlgo::CopyNodeAttr(kAttrKeepDims, reduce_sum, fusion_node); | |||||
| } | |||||
| } // namespace | |||||
| const BaseRef ConfusionSoftmaxGradRule::DefinePattern() const { | const BaseRef ConfusionSoftmaxGradRule::DefinePattern() const { | ||||
| return VectorRef( | return VectorRef( | ||||
| {prim::kPrimSub, input0_, VectorRef({prim::kPrimReduceSum, VectorRef({prim::kPrimMul, input0_, input1_})})}); | {prim::kPrimSub, input0_, VectorRef({prim::kPrimReduceSum, VectorRef({prim::kPrimMul, input0_, input1_})})}); | ||||
| @@ -48,6 +69,7 @@ const AnfNodePtr ConfusionSoftmaxGradRule::Process(const FuncGraphPtr &graph, co | |||||
| auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)}; | auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)}; | ||||
| AnfAlgo::SetOutputInferTypeAndShape(types, shapes, confusion_softmax_grad.get()); | AnfAlgo::SetOutputInferTypeAndShape(types, shapes, confusion_softmax_grad.get()); | ||||
| confusion_softmax_grad->set_scope(node->scope()); | confusion_softmax_grad->set_scope(node->scope()); | ||||
| SetAttrsForFusionNode(node, confusion_softmax_grad); | |||||
| return confusion_softmax_grad; | return confusion_softmax_grad; | ||||
| } | } | ||||
| } // namespace opt | } // namespace opt | ||||