| @@ -94,6 +94,7 @@ void CreateOutputShapeAndTypeId(const CNodePtr &origin_cnode, int split_dim, int | |||
| } | |||
| void SetAttrAndAbstractForBaseSplitv(const CNodePtr &origin_cnode, const CNodePtr &base_splitv, | |||
| const std::vector<AnfNodePtr> &base_splitv_outputs, | |||
| const std::vector<int> &size_splits_base, int split_dim, int num_split) { | |||
| SetAttrForSplitVNode(base_splitv, size_splits_base, split_dim, num_split); | |||
| auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0); | |||
| @@ -106,6 +107,7 @@ void SetAttrAndAbstractForBaseSplitv(const CNodePtr &origin_cnode, const CNodePt | |||
| for (int i = 0; i < num_split; ++i) { | |||
| output_shape[split_dim] = size_splits_base[i]; | |||
| base_output_shapes_base.emplace_back(output_shape); | |||
| AnfAlgo::SetOutputInferTypeAndShape({type_id}, {output_shape}, base_splitv_outputs[i].get()); | |||
| } | |||
| AnfAlgo::SetOutputInferTypeAndShape(base_type_ids, base_output_shapes_base, base_splitv.get()); | |||
| } | |||
| @@ -127,11 +129,14 @@ AnfNodePtr DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int | |||
| std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)}; | |||
| // Start to divide the outputs of Split. | |||
| std::vector<int> size_splits_base; | |||
| std::vector<AnfNodePtr> base_splitv_outputs; | |||
| const auto base_split_size = divisor * small_split_size; | |||
| int nodes_num = 0; | |||
| int cur_output_index = 0; | |||
| while (num_split - cur_output_index > divisor) { | |||
| CNodePtr new_splitv = CreateSplitVNode(func_graph, CreateTupleGetItem(func_graph, base_splitv, nodes_num)); | |||
| auto tuple_getitem = CreateTupleGetItem(func_graph, base_splitv, nodes_num); | |||
| base_splitv_outputs.push_back(tuple_getitem); | |||
| CNodePtr new_splitv = CreateSplitVNode(func_graph, tuple_getitem); | |||
| SetAttrForSplitVNode(new_splitv, size_splits_new, split_dim, divisor); | |||
| AnfAlgo::SetOutputInferTypeAndShape(new_type_ids, new_output_shapes, new_splitv.get()); | |||
| AddNewOutputs(func_graph, new_splitv, divisor, &make_tuple_inputs); | |||
| @@ -142,7 +147,9 @@ AnfNodePtr DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int | |||
| if (cur_output_index < num_split) { | |||
| auto last_node_num_split = num_split - cur_output_index; | |||
| if (last_node_num_split > 1) { | |||
| CNodePtr new_splitv = CreateSplitVNode(func_graph, CreateTupleGetItem(func_graph, base_splitv, nodes_num)); | |||
| auto tuple_getitem = CreateTupleGetItem(func_graph, base_splitv, nodes_num); | |||
| base_splitv_outputs.push_back(tuple_getitem); | |||
| CNodePtr new_splitv = CreateSplitVNode(func_graph, tuple_getitem); | |||
| std::vector<int> size_splits_new_last(last_node_num_split, small_split_size); | |||
| SetAttrForSplitVNode(new_splitv, size_splits_new_last, split_dim, last_node_num_split); | |||
| // Create new output shape and new output type id for the last Splitv node | |||
| @@ -154,13 +161,15 @@ AnfNodePtr DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int | |||
| AddNewOutputs(func_graph, new_splitv, last_node_num_split, &make_tuple_inputs); | |||
| size_splits_base.emplace_back(last_node_num_split * small_split_size); | |||
| } else { | |||
| make_tuple_inputs.emplace_back(CreateTupleGetItem(func_graph, base_splitv, nodes_num)); | |||
| auto tuple_getitem = CreateTupleGetItem(func_graph, base_splitv, nodes_num); | |||
| base_splitv_outputs.push_back(tuple_getitem); | |||
| make_tuple_inputs.emplace_back(tuple_getitem); | |||
| size_splits_base.emplace_back(small_split_size); | |||
| } | |||
| nodes_num++; | |||
| } | |||
| // Set Attr and abstract for the base splitv | |||
| SetAttrAndAbstractForBaseSplitv(cnode, base_splitv, size_splits_base, split_dim, nodes_num); | |||
| SetAttrAndAbstractForBaseSplitv(cnode, base_splitv, base_splitv_outputs, size_splits_base, split_dim, nodes_num); | |||
| AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs); | |||
| return make_tuple; | |||
| } | |||