|
|
|
@@ -73,6 +73,12 @@ const AnfNodePtr AddReluGradV2Fusion::Process(const FuncGraphPtr &graph, const A |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<size_t> shape1 = AnfAlgo::GetPrevNodeOutputInferShape(tensor_add, 0); |
|
|
|
std::vector<size_t> shape2 = AnfAlgo::GetPrevNodeOutputInferShape(tensor_add, 1); |
|
|
|
if (shape1 != shape2) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
auto prim = std::make_shared<Primitive>(kFusedAddReluGradV2Name); |
|
|
|
MS_EXCEPTION_IF_NULL(prim); |
|
|
|
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x1, x2, mask}; |
|
|
|
|