|
|
|
@@ -24,40 +24,57 @@ |
|
|
|
#include "pre_activate/common/helper.h" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
bool GetMul(const FuncGraphPtr &graph, const CNodePtr &add, CNodePtr *mul, size_t *mul_index) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
MS_EXCEPTION_IF_NULL(add); |
|
|
|
|
|
|
|
for (size_t index = 1; index < add->size(); ++index) { |
|
|
|
auto input = add->input(index); |
|
|
|
MS_EXCEPTION_IF_NULL(input); |
|
|
|
if (input->isa<CNode>()) { |
|
|
|
auto cnode = input->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimMul->name()) { |
|
|
|
if (!opt::IsUsedByOthers(graph, cnode)) { |
|
|
|
*mul = cnode; |
|
|
|
*mul_index = index; |
|
|
|
return true; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
namespace opt { |
|
|
|
const BaseRef MulAddFusion::DefinePattern() const { |
|
|
|
VarPtr mul_x_ = std::make_shared<Var>(); |
|
|
|
VarPtr mul_y_ = std::make_shared<Var>(); |
|
|
|
VarPtr add_y_ = std::make_shared<Var>(); |
|
|
|
|
|
|
|
VectorRef mul({prim::kPrimMul, mul_x_, mul_y_}); |
|
|
|
VectorRef add({prim::kPrimTensorAdd, mul, add_y_}); |
|
|
|
return add; |
|
|
|
VarPtr x = std::make_shared<Var>(); |
|
|
|
VarPtr y = std::make_shared<Var>(); |
|
|
|
VectorRef pattern({prim::kPrimTensorAdd, x, y}); |
|
|
|
return pattern; |
|
|
|
} |
|
|
|
|
|
|
|
const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &equiv) const { |
|
|
|
if (graph == nullptr || node == nullptr || equiv == nullptr) { |
|
|
|
const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const { |
|
|
|
if (graph == nullptr || node == nullptr) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto add = node->cast<CNodePtr>(); |
|
|
|
if (add == nullptr || add->inputs().size() != kAddInputNum) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto mul_anf = add->input(1); |
|
|
|
if (mul_anf == nullptr) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto mul = mul_anf->cast<CNodePtr>(); |
|
|
|
if (mul == nullptr || mul->inputs().size() != kMulInputNum) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
if (IsUsedByOthers(graph, mul)) { |
|
|
|
MS_LOG(DEBUG) << "Mul is used by more then two nodes, cannot fuse"; |
|
|
|
CNodePtr mul = nullptr; |
|
|
|
size_t mul_index = 0; |
|
|
|
if (!GetMul(graph, add, &mul, &mul_index) || mul == nullptr || mul_index == 0) { |
|
|
|
MS_LOG(DEBUG) << "Cannot find used-by-only-one-op Mul in Add's inputs"; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
auto prim = std::make_shared<Primitive>(kFusedMulAddOpName); |
|
|
|
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), mul->input(1), mul->input(2), add->input(2)}; |
|
|
|
std::vector<AnfNodePtr> inputs = {NewValueNode(prim)}; |
|
|
|
for (size_t index = 1; index < mul->size(); ++index) { |
|
|
|
inputs.push_back(mul->input(index)); |
|
|
|
} |
|
|
|
inputs.push_back(add->input(add->size() - mul_index)); |
|
|
|
auto fusion_node = graph->NewCNode(inputs); |
|
|
|
fusion_node->set_scope(add->scope()); |
|
|
|
fusion_node->set_abstract(add->abstract()); |
|
|
|
|