|
|
|
@@ -21,9 +21,30 @@ |
|
|
|
#include "session/anf_runtime_algorithm.h" |
|
|
|
#include "ir/primitive.h" |
|
|
|
#include "utils/utils.h" |
|
|
|
#include "pre_activate/common/helper.h" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace opt { |
|
|
|
namespace { |
|
|
|
void SetAttrsForFusionNode(const AnfNodePtr &sub_anf, const AnfNodePtr &fusion_node) { |
|
|
|
MS_EXCEPTION_IF_NULL(sub_anf); |
|
|
|
MS_EXCEPTION_IF_NULL(fusion_node); |
|
|
|
auto sub = sub_anf->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(sub); |
|
|
|
if (sub->size() != kSubInputNum) { |
|
|
|
MS_LOG(EXCEPTION) << "Sub's size is not equal with 3"; |
|
|
|
} |
|
|
|
auto reduce_sum_anf = sub->input(2); |
|
|
|
MS_EXCEPTION_IF_NULL(reduce_sum_anf); |
|
|
|
auto reduce_sum = reduce_sum_anf->cast<CNodePtr>(); |
|
|
|
if (reduce_sum == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "Sub's second input is not a cnode"; |
|
|
|
} |
|
|
|
AnfAlgo::CopyNodeAttr(kAttrAxis, reduce_sum, fusion_node); |
|
|
|
AnfAlgo::CopyNodeAttr(kAttrKeepDims, reduce_sum, fusion_node); |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
const BaseRef ConfusionSoftmaxGradRule::DefinePattern() const { |
|
|
|
return VectorRef( |
|
|
|
{prim::kPrimSub, input0_, VectorRef({prim::kPrimReduceSum, VectorRef({prim::kPrimMul, input0_, input1_})})}); |
|
|
|
@@ -48,6 +69,7 @@ const AnfNodePtr ConfusionSoftmaxGradRule::Process(const FuncGraphPtr &graph, co |
|
|
|
auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)}; |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, confusion_softmax_grad.get()); |
|
|
|
confusion_softmax_grad->set_scope(node->scope()); |
|
|
|
SetAttrsForFusionNode(node, confusion_softmax_grad); |
|
|
|
return confusion_softmax_grad; |
|
|
|
} |
|
|
|
} // namespace opt |
|
|
|
|