|
|
|
@@ -34,6 +34,8 @@ AnfNodePtr CreateNewAddn(const FuncGraphPtr &func_graph, const CNodePtr &origin_ |
|
|
|
new_addn->set_scope(origin_addn_cnode->scope()); |
|
|
|
new_addn->set_abstract(origin_addn_cnode->abstract()); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(offset)), new_addn); |
|
|
|
std::vector<int> dyn_input_sizes{SizeToInt(offset)}; |
|
|
|
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), new_addn); |
|
|
|
return new_addn; |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
@@ -55,22 +57,24 @@ const AnfNodePtr AddnFission::Process(const FuncGraphPtr &func_graph, const AnfN |
|
|
|
} |
|
|
|
CNodePtr new_cnode = cnode; |
|
|
|
while (origin_input_size > inputs_divisor_) { |
|
|
|
MS_EXCEPTION_IF_NULL(new_cnode); |
|
|
|
std::vector<AnfNodePtr> base_addn_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimAddN->name()))}; |
|
|
|
size_t cur_input_index = 1; |
|
|
|
// Divide the inputs of addn by 63. |
|
|
|
while (origin_input_size - cur_input_index + 1 > inputs_divisor_) { |
|
|
|
// Divide the inputs of addn by inputs_divisor_. |
|
|
|
while (origin_input_size - cur_input_index + 1 >= inputs_divisor_) { |
|
|
|
base_addn_inputs.push_back(CreateNewAddn(func_graph, new_cnode, cur_input_index, inputs_divisor_)); |
|
|
|
cur_input_index += inputs_divisor_; |
|
|
|
} |
|
|
|
base_addn_inputs.push_back( |
|
|
|
CreateNewAddn(func_graph, new_cnode, cur_input_index, origin_input_size - cur_input_index + 1)); |
|
|
|
|
|
|
|
for (size_t i = cur_input_index; i <= origin_input_size; i++) { |
|
|
|
base_addn_inputs.push_back(new_cnode->input(i)); |
|
|
|
} |
|
|
|
CNodePtr base_addn = func_graph->NewCNode(base_addn_inputs); |
|
|
|
MS_EXCEPTION_IF_NULL(base_addn); |
|
|
|
MS_EXCEPTION_IF_NULL(new_cnode); |
|
|
|
base_addn->set_scope(new_cnode->scope()); |
|
|
|
base_addn->set_abstract(new_cnode->abstract()); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(base_addn_inputs.size() - 1)), base_addn); |
|
|
|
std::vector<int> dyn_input_sizes{SizeToInt(base_addn_inputs.size() - 1)}; |
|
|
|
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), base_addn); |
|
|
|
new_cnode = base_addn; |
|
|
|
origin_input_size = base_addn->inputs().size() - 1; |
|
|
|
} |
|
|
|
|