|
|
@@ -47,7 +47,7 @@ void SetAttrsForFusionNode(const AnfNodePtr &sub_anf, const AnfNodePtr &fusion_n |
|
|
|
|
|
|
|
|
const BaseRef ConfusionSoftmaxGradRule::DefinePattern() const { |
|
|
const BaseRef ConfusionSoftmaxGradRule::DefinePattern() const { |
|
|
return VectorRef( |
|
|
return VectorRef( |
|
|
{prim::kPrimSub, input0_, VectorRef({prim::kPrimReduceSum, VectorRef({prim::kPrimMul, input0_, input1_})})}); |
|
|
|
|
|
|
|
|
{prim::kPrimSub, input0_, VectorRef({prim::kPrimReduceSum, VectorRef({prim::kPrimMul, input1_, input0_})})}); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
const AnfNodePtr ConfusionSoftmaxGradRule::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, |
|
|
const AnfNodePtr ConfusionSoftmaxGradRule::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, |
|
|
|