|
|
|
@@ -32,7 +32,21 @@ bool CheckValueNodeInputOfMul(const AnfNodePtr &node) { |
|
|
|
std::vector<size_t> mul_input_shape = AnfAlgo::GetOutputInferShape(node, 0); |
|
|
|
return mul_input_shape.empty() || (mul_input_shape.size() == 1 && mul_input_shape[0] == 1); |
|
|
|
} |
|
|
|
void AddInputToOutput(const FuncGraphPtr &func_graph, const CNodePtr &old_cnode, const AnfNodePtr &new_node, |
|
|
|
std::vector<AnfNodePtr> *new_outputs) { |
|
|
|
MS_EXCEPTION_IF_NULL(old_cnode); |
|
|
|
MS_EXCEPTION_IF_NULL(new_node); |
|
|
|
MS_EXCEPTION_IF_NULL(new_outputs); |
|
|
|
auto node_to_output = old_cnode->input(kAccumIndex + 1); |
|
|
|
MS_EXCEPTION_IF_NULL(node_to_output); |
|
|
|
AbstractBasePtrList abstract_list{old_cnode->abstract(), node_to_output->abstract()}; |
|
|
|
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list); |
|
|
|
new_node->set_abstract(abstract_tuple); |
|
|
|
// Create Output |
|
|
|
CreateMultipleOutputsOfAnfNode(func_graph, new_node, kFusedMulApplyMomentumOutputNum, new_outputs); |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
const BaseRef MomentumLossscaleFusion::DefinePattern() const { |
|
|
|
VarPtr Xs = std::make_shared<SeqVar>(); |
|
|
|
VarPtr X0 = std::make_shared<Var>(); |
|
|
|
@@ -80,15 +94,10 @@ const AnfNodePtr MomentumLossscaleFusion::Process(const FuncGraphPtr &func_graph |
|
|
|
input_names_value[3] = "x1"; |
|
|
|
input_names_value.emplace_back("x2"); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names_value), new_node); |
|
|
|
auto node_to_output = cnode->input(kAccumIndex + 1); |
|
|
|
MS_EXCEPTION_IF_NULL(node_to_output); |
|
|
|
AbstractBasePtrList abstract_list{node->abstract(), node_to_output->abstract()}; |
|
|
|
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list); |
|
|
|
new_node->set_abstract(abstract_tuple); |
|
|
|
new_node->set_scope(node->scope()); |
|
|
|
// Create Output |
|
|
|
// Create Outputs |
|
|
|
std::vector<AnfNodePtr> new_outputs; |
|
|
|
CreateMultipleOutputsOfAnfNode(func_graph, new_node, kFusedMulApplyMomentumOutputNum, &new_outputs); |
|
|
|
AddInputToOutput(func_graph, cnode, new_node, &new_outputs); |
|
|
|
if (new_outputs.size() != kFusedMulApplyMomentumOutputNum) { |
|
|
|
MS_LOG(EXCEPTION) << "Failed to create outputs of " << new_node->DebugString(); |
|
|
|
} |
|
|
|
|