Merge pull request !472 from YuJianfeng/mastertags/v0.2.0-alpha
| @@ -34,6 +34,8 @@ AnfNodePtr CreateNewAddn(const FuncGraphPtr &func_graph, const CNodePtr &origin_ | |||||
| new_addn->set_scope(origin_addn_cnode->scope()); | new_addn->set_scope(origin_addn_cnode->scope()); | ||||
| new_addn->set_abstract(origin_addn_cnode->abstract()); | new_addn->set_abstract(origin_addn_cnode->abstract()); | ||||
| AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(offset)), new_addn); | AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(offset)), new_addn); | ||||
| std::vector<int> dyn_input_sizes{SizeToInt(offset)}; | |||||
| AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), new_addn); | |||||
| return new_addn; | return new_addn; | ||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| @@ -55,22 +57,24 @@ const AnfNodePtr AddnFission::Process(const FuncGraphPtr &func_graph, const AnfN | |||||
| } | } | ||||
| CNodePtr new_cnode = cnode; | CNodePtr new_cnode = cnode; | ||||
| while (origin_input_size > inputs_divisor_) { | while (origin_input_size > inputs_divisor_) { | ||||
| MS_EXCEPTION_IF_NULL(new_cnode); | |||||
| std::vector<AnfNodePtr> base_addn_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimAddN->name()))}; | std::vector<AnfNodePtr> base_addn_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimAddN->name()))}; | ||||
| size_t cur_input_index = 1; | size_t cur_input_index = 1; | ||||
| // Divide the inputs of addn by 63. | |||||
| while (origin_input_size - cur_input_index + 1 > inputs_divisor_) { | |||||
| // Divide the inputs of addn by inputs_divisor_. | |||||
| while (origin_input_size - cur_input_index + 1 >= inputs_divisor_) { | |||||
| base_addn_inputs.push_back(CreateNewAddn(func_graph, new_cnode, cur_input_index, inputs_divisor_)); | base_addn_inputs.push_back(CreateNewAddn(func_graph, new_cnode, cur_input_index, inputs_divisor_)); | ||||
| cur_input_index += inputs_divisor_; | cur_input_index += inputs_divisor_; | ||||
| } | } | ||||
| base_addn_inputs.push_back( | |||||
| CreateNewAddn(func_graph, new_cnode, cur_input_index, origin_input_size - cur_input_index + 1)); | |||||
| for (size_t i = cur_input_index; i <= origin_input_size; i++) { | |||||
| base_addn_inputs.push_back(new_cnode->input(i)); | |||||
| } | |||||
| CNodePtr base_addn = func_graph->NewCNode(base_addn_inputs); | CNodePtr base_addn = func_graph->NewCNode(base_addn_inputs); | ||||
| MS_EXCEPTION_IF_NULL(base_addn); | MS_EXCEPTION_IF_NULL(base_addn); | ||||
| MS_EXCEPTION_IF_NULL(new_cnode); | |||||
| base_addn->set_scope(new_cnode->scope()); | base_addn->set_scope(new_cnode->scope()); | ||||
| base_addn->set_abstract(new_cnode->abstract()); | base_addn->set_abstract(new_cnode->abstract()); | ||||
| AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(base_addn_inputs.size() - 1)), base_addn); | AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(base_addn_inputs.size() - 1)), base_addn); | ||||
| std::vector<int> dyn_input_sizes{SizeToInt(base_addn_inputs.size() - 1)}; | |||||
| AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), base_addn); | |||||
| new_cnode = base_addn; | new_cnode = base_addn; | ||||
| origin_input_size = base_addn->inputs().size() - 1; | origin_input_size = base_addn->inputs().size() - 1; | ||||
| } | } | ||||
| @@ -149,7 +149,7 @@ constexpr auto kAttrDynInputSizes = "dyn_input_sizes"; | |||||
| constexpr auto kAttrSrcFormat = "src_format"; | constexpr auto kAttrSrcFormat = "src_format"; | ||||
| constexpr auto kAttrOutputUsedNum = "output_used_num"; | constexpr auto kAttrOutputUsedNum = "output_used_num"; | ||||
| constexpr auto kAttrHasBias = "has_bias"; | constexpr auto kAttrHasBias = "has_bias"; | ||||
| constexpr auto kAttrN = "N"; | |||||
| constexpr auto kAttrN = "n"; | |||||
| constexpr auto kAttrLabelForInsertStreamActive = "label_for_insert_stream_active"; | constexpr auto kAttrLabelForInsertStreamActive = "label_for_insert_stream_active"; | ||||
| // attr value | // attr value | ||||
| @@ -45,13 +45,10 @@ def test_addn_fission(tag): | |||||
| b = addn((input2, input3)) | b = addn((input2, input3)) | ||||
| c = addn((input4, input5)) | c = addn((input4, input5)) | ||||
| d = addn((input6, input7)) | d = addn((input6, input7)) | ||||
| e = addn((input8,)) | |||||
| f = addn((a, b)) | f = addn((a, b)) | ||||
| g = addn((c, d)) | g = addn((c, d)) | ||||
| h = addn((e,)) | |||||
| i = addn((f, g)) | i = addn((f, g)) | ||||
| j = addn((h,)) | |||||
| return addn((i, j)) | |||||
| return addn((i, input8)) | |||||
| @fns | @fns | ||||
| def after_divided_by_3(input0, input1, input2, input3, input4, input5, input6, input7, input8): | def after_divided_by_3(input0, input1, input2, input3, input4, input5, input6, input7, input8): | ||||
| @@ -64,14 +61,12 @@ def test_addn_fission(tag): | |||||
| def after_divided_by_4(input0, input1, input2, input3, input4, input5, input6, input7, input8): | def after_divided_by_4(input0, input1, input2, input3, input4, input5, input6, input7, input8): | ||||
| a = addn((input0, input1, input2, input3)) | a = addn((input0, input1, input2, input3)) | ||||
| b = addn((input4, input5, input6, input7)) | b = addn((input4, input5, input6, input7)) | ||||
| c = addn((input8,)) | |||||
| return addn((a, b, c)) | |||||
| return addn((a, b, input8)) | |||||
| @fns | @fns | ||||
| def after_divided_by_8(input0, input1, input2, input3, input4, input5, input6, input7, input8): | def after_divided_by_8(input0, input1, input2, input3, input4, input5, input6, input7, input8): | ||||
| a = addn((input0, input1, input2, input3, input4, input5, input6, input7)) | a = addn((input0, input1, input2, input3, input4, input5, input6, input7)) | ||||
| b = addn((input8,)) | |||||
| return addn((a, b)) | |||||
| return addn((a, input8)) | |||||
| @fns | @fns | ||||
| def after_divided_by_9(input0, input1, input2, input3, input4, input5, input6, input7, input8): | def after_divided_by_9(input0, input1, input2, input3, input4, input5, input6, input7, input8): | ||||