You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

fisson_util.cc 20 kB

optimize the comment and log description 修改: ops/operations/_inner_ops.py 修改: ops/operations/_quant_ops.py 修改: ops/operations/array_ops.py 修改: ops/operations/comm_ops.py 修改: ops/operations/math_ops.py 修改: ops/operations/quantum_ops.py 修改: ops/operations/rl_ops.py 修改: ops/operations/sponge_ops.py 修改: ops/operations/sponge_update_ops.py 修改: train/__init__.py 修改: common/tensor.py 修改: train/serialization.py 修改: ccsrc/pipeline/jit/parse/parse.h 修改: explainer/benchmark/_attribution/metric.py 修改: ops/composite/multitype_ops/_constexpr_utils.py 修改: ops/operations/comm_ops.py 修改: RELEASE.md 修改: mindspore/_extends/parse/standard_method.py 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/concat_offset_cpu_kernel.cc 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/dynamic_shape_cpu_kernel.cc 修改: mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc 修改: mindspore/ccsrc/frontend/parallel/ops_info/tile_info.cc 修改: mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.cc 修改: mindspore/ccsrc/frontend/parallel/strategy.h 修改: mindspore/common/tensor.py 修改: mindspore/core/abstract/prim_arrays.cc 修改: mindspore/core/abstract/prim_nn.cc 修改: mindspore/core/ops/conv2d.cc 修改: mindspore/core/ops/logical_and.h 修改: mindspore/core/ops/logical_not.h 修改: mindspore/core/ops/logical_or.h 修改: mindspore/core/ops/reduce_all.h 修改: mindspore/core/ops/reduce_any.h 修改: mindspore/lite/src/runtime/kernel/arm/fp32_grad/sgd.cc 修改: mindspore/nn/layer/quant.py 修改: mindspore/nn/optim/sgd.py 修改: mindspore/nn/sparse/sparse.py 修改: mindspore/numpy/array_creations.py 修改: mindspore/numpy/array_ops.py 修改: mindspore/numpy/logic_ops.py 修改: mindspore/numpy/math_ops.py 修改: mindspore/ops/operations/_inner_ops.py 修改: mindspore/ops/operations/array_ops.py 修改: mindspore/ops/operations/rl_ops.py 修改: mindspore/train/_utils.py 修改: tests/ut/python/model/test_lenet_core_after_exception.py 修改: mindspore/_extends/parse/standard_method.py 修改: mindspore/ops/operations/rl_ops.py 修改: mindspore/core/abstract/prim_nn.cc 修改: 
mindspore/core/ops/conv2d.cc 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/ctcloss_cpu_kernel.cc 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/fl/fused_pull_weight_kernel.h 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/fl/fused_push_weight_kernel.h 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.cc 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.cc 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.cc 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/rolling_cpu_kernel.cc 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_arithmetic_cpu_kernel.cc 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/split_cpu_kernel.cc 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/update_cache_cpu_kernel.cc 修改: mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.h 修改: mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h 修改: mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h 修改: mindspore/ccsrc/fl/server/server.cc 修改: mindspore/ccsrc/frontend/optimizer/ad/kpynative.cc 修改: mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h 修改: mindspore/ccsrc/frontend/optimizer/irpass/inline.h 修改: mindspore/ccsrc/minddata/dataset/core/device_tensor.cc 修改: mindspore/ccsrc/minddata/dataset/core/tensor.cc 修改: mindspore/ccsrc/minddata/dataset/engine/datasetops/source/emnist_op.cc 修改: mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc 修改: mindspore/ccsrc/minddata/dataset/engine/datasetops/source/qmnist_op.cc 修改: mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc 修改: mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_ctrl_pass.cc 修改: mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.cc 修改: mindspore/ccsrc/pipeline/jit/action.cc 修改: mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc 
修改: mindspore/ccsrc/runtime/device/ascend/executor/tiling/op_tiling_adapter.cc 修改: mindspore/compression/quant/quant_utils.py 修改: mindspore/core/abstract/prim_nn.cc 修改: mindspore/dataset/engine/validators.py 修改: mindspore/lite/micro/coder/opcoders/nnacl/fp32/affine_fp32_coder.cc 修改: mindspore/lite/micro/coder/opcoders/nnacl/int8/affine_int8_coder.cc 修改: mindspore/lite/src/runtime/kernel/ascend310/src/custom_kernel.cc 修改: mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc 修改: mindspore/lite/src/runtime/kernel/opencl/kernel/strassen.cc 修改: mindspore/lite/tools/common/graph_util.h 修改: mindspore/lite/tools/optimizer/fisson/fisson_util.cc 修改: mindspore/ops/composite/math_ops.py 修改: mindspore/ops/operations/_inner_ops.py 修改: mindspore/ops/operations/array_ops.py 修改: mindspore/ops/operations/math_ops.py 修改: mindspore/ops/operations/other_ops.py 修改: mindspore/boost/boost_cell_wrapper.py 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/update_cache_cpu_kernel.cc 修改: mindspore/ccsrc/common/trans.cc 修改: mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc 修改: mindspore/ccsrc/frontend/parallel/ops_info/gather_info.cc 修改: mindspore/lite/src/common/log_util.h 修改: mindspore/nn/wrap/loss_scale.py 修改: mindspore/parallel/nn/moe.py 修改: tests/mindspore_test_framework/mindspore_test.py 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/split_cpu_kernel.cc 修改: mindspore/lite/tools/common/graph_util.h 修改: mindspore/ccsrc/frontend/parallel/ops_info/gather_info.cc 修改: mindspore/core/ops/conv2d.cc 修改: tests/ut/python/model/test_lenet_core_after_exception.py
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "tools/optimizer/fisson/fisson_util.h"
  17. #include <unordered_set>
  18. #include <memory>
  19. #include "base/core_ops.h"
  20. #include "src/common/utils.h"
  21. #include "ops/split_with_overlap.h"
  22. #include "tools/common/node_util.h"
  23. #include "ops/concat.h"
  24. #include "tools/optimizer/parallel/spliter.h"
  25. #include "tools/optimizer/parallel/split_strategy.h"
  26. #include "nnacl/op_base.h"
  27. #include "src/common/log_util.h"
  28. using mindspore::converter::FmkType;
  29. namespace mindspore {
  30. namespace opt {
  31. std::vector<int64_t> GetSplitPadList(const std::shared_ptr<ops::Conv2DFusion> &ori_conv_prim, int64_t input_h,
  32. int64_t input_w) {
  33. if (ori_conv_prim == nullptr) {
  34. MS_LOG(DEBUG) << "input Conv2DFusion is nullptr";
  35. return {};
  36. }
  37. if (ori_conv_prim->get_pad_mode() != SAME) {
  38. return ori_conv_prim->get_pad_list();
  39. }
  40. if (ori_conv_prim->get_stride().size() < kIndexW || ori_conv_prim->get_kernel_size().size() < kIndexW ||
  41. ori_conv_prim->get_dilation().size() < kIndexW) {
  42. MS_LOG(ERROR) << "Index out of range";
  43. return {};
  44. }
  45. int64_t output_h = static_cast<int64_t>(
  46. std::ceil(static_cast<float>(input_h) / static_cast<float>(ori_conv_prim->get_stride().at(kIndexH))));
  47. int64_t output_w = static_cast<int64_t>(
  48. std::ceil(static_cast<float>(input_w) / static_cast<float>(ori_conv_prim->get_stride().at(kIndexW))));
  49. auto kernel_h = ori_conv_prim->get_kernel_size().at(kIndexH);
  50. auto dilation_h = ori_conv_prim->get_dilation().at(kIndexH);
  51. auto kernel_w = ori_conv_prim->get_kernel_size().at(kIndexW);
  52. auto dilation_w = ori_conv_prim->get_dilation().at(kIndexW);
  53. if (INT_MUL_OVERFLOW_THRESHOLD((kernel_h - 1), dilation_h, INT64_MAX) ||
  54. INT_MUL_OVERFLOW_THRESHOLD((kernel_w - 1), dilation_w, INT64_MAX)) {
  55. MS_LOG(ERROR) << "int mul overflow";
  56. return {};
  57. }
  58. std::vector<int64_t> new_pad_list;
  59. int64_t pad_up = 0, pad_down = 0, pad_left = 0, pad_right = 0;
  60. int64_t pad_h_all =
  61. (output_h - 1) * ori_conv_prim->get_stride().at(kIndexH) + (kernel_h - 1) * dilation_h + 1 - input_h;
  62. int64_t pad_w_all =
  63. (output_w - 1) * ori_conv_prim->get_stride().at(kIndexW) + (kernel_w - 1) * dilation_w + 1 - input_w;
  64. // only check pad_up and pad_down is positive
  65. // if compute overflowed, we will get abnormal it in infer_shape
  66. if (pad_h_all >= 0) {
  67. pad_up = pad_h_all / 2;
  68. pad_down = pad_h_all - pad_up;
  69. }
  70. new_pad_list.push_back(pad_up);
  71. new_pad_list.push_back(pad_down);
  72. if (pad_w_all >= 0) {
  73. pad_left = pad_w_all / 2;
  74. pad_right = pad_w_all - pad_left;
  75. }
  76. new_pad_list.push_back(pad_left);
  77. new_pad_list.push_back(pad_right);
  78. return new_pad_list;
  79. }
  80. namespace {
  81. bool CalSplitOutputShape(int64_t splited_axis_value, const SplitInfo *split_info,
  82. std::vector<int64_t> *split_axis_out_shape,
  83. std::vector<int64_t> *split_axis_reduce_out_shape) {
  84. MS_ASSERT(split_info != nullptr && split_axis_out_shape != nullptr && split_axis_reduce_out_shape != nullptr);
  85. // ori ratio
  86. int64_t split_num = split_info->out_num;
  87. int64_t split_len = 0;
  88. for (int64_t i = 0; i < split_num; i++) {
  89. split_len += split_info->size_splits[i];
  90. }
  91. if (split_len > splited_axis_value) {
  92. return false;
  93. }
  94. // out-shape after splited
  95. int64_t tmp_value = 0;
  96. MS_CHECK_TRUE_MSG(split_num > 0, false, "out_num of split_info should be greater than zero");
  97. MS_CHECK_TRUE_MSG(split_len > 0, false, "split_len should be greater than zero");
  98. for (int64_t i = 0; i < split_num - 1; i++) {
  99. if (INT_MUL_OVERFLOW_THRESHOLD(split_info->size_splits[i], splited_axis_value, INT64_MAX)) {
  100. MS_LOG(ERROR) << "int mul overflow";
  101. return false;
  102. }
  103. int64_t tmp = UP_DIV(split_info->size_splits[i] * splited_axis_value, split_len);
  104. tmp_value += tmp;
  105. split_axis_out_shape->push_back(tmp);
  106. split_axis_reduce_out_shape->push_back(tmp_value);
  107. }
  108. split_axis_out_shape->push_back(splited_axis_value - tmp_value);
  109. split_axis_reduce_out_shape->push_back(splited_axis_value);
  110. return true;
  111. }
  112. bool CalSplitInShape(const std::vector<std::vector<ShapeVector>> &node_in_out_shapes, const SplitInfo *split_info,
  113. const std::shared_ptr<ops::Conv2DFusion> &ori_conv_prim, size_t index_node,
  114. std::vector<std::vector<int64_t>> *split_axis_inputs_shape,
  115. std::vector<std::vector<int64_t>> *split_axis_reduce_inputs_shape) {
  116. MS_ASSERT(split_info != nullptr && ori_conv_prim != nullptr && split_axis_inputs_shape != nullptr &&
  117. split_axis_reduce_inputs_shape != nullptr);
  118. MS_ASSERT(node_in_out_shapes.size() > index_node);
  119. auto in_out_shape = node_in_out_shapes.at(index_node);
  120. MS_ASSERT(!in_out_shape.empty());
  121. auto in_shape = in_out_shape.front();
  122. if (in_shape.size() < kAxisW) {
  123. MS_LOG(DEBUG) << "out of in_shape range";
  124. return false;
  125. }
  126. int64_t input_h = in_shape.at(kAxisH);
  127. int64_t input_w = in_shape.at(kAxisW);
  128. auto new_pad_list = GetSplitPadList(ori_conv_prim, input_h, input_w);
  129. ori_conv_prim->set_pad_list(new_pad_list);
  130. int64_t split_num = split_info->out_num;
  131. int64_t tmp = 0;
  132. std::vector<int64_t> split_axis_shape;
  133. std::vector<int64_t> split_axis_reduce_shape;
  134. // iter splited_num
  135. for (int64_t index = 0; index < split_num; index++) {
  136. // shape
  137. auto stride_h = ori_conv_prim->get_stride()[kIndexH];
  138. auto split_axis_dim = (*split_axis_inputs_shape)[index_node][index] - 1;
  139. if (INT_MUL_OVERFLOW_THRESHOLD(stride_h, split_axis_dim, INT64_MAX)) {
  140. MS_LOG(ERROR) << "int mul overflow";
  141. return false;
  142. }
  143. if (split_info->axis == CuttingStragedy::CUT_H) { // H
  144. if (index == 0) {
  145. tmp =
  146. stride_h * split_axis_dim - ori_conv_prim->get_pad_list()[kPadUp] + ori_conv_prim->get_kernel_size()[kIndexH];
  147. } else if (index == split_num - 1) {
  148. tmp = stride_h * split_axis_dim - ori_conv_prim->get_pad_list()[kPadDown] +
  149. ori_conv_prim->get_kernel_size()[kIndexH];
  150. } else {
  151. tmp = stride_h * split_axis_dim + ori_conv_prim->get_kernel_size()[kIndexH];
  152. }
  153. }
  154. split_axis_shape.push_back(tmp);
  155. // reduce shape
  156. auto split_axis_reduce_dim = (*split_axis_reduce_inputs_shape)[index_node][index] - 1;
  157. if (split_info->axis == CuttingStragedy::CUT_H) { // H
  158. if (index == split_num - 1) {
  159. tmp = stride_h * split_axis_reduce_dim - ori_conv_prim->get_pad_list()[kPadDown] -
  160. ori_conv_prim->get_pad_list()[kPadUp] + ori_conv_prim->get_kernel_size()[kIndexH];
  161. } else {
  162. tmp = stride_h * split_axis_reduce_dim - ori_conv_prim->get_pad_list()[kPadUp] +
  163. ori_conv_prim->get_kernel_size()[kIndexH];
  164. }
  165. }
  166. split_axis_reduce_shape.push_back(tmp);
  167. }
  168. split_axis_inputs_shape->push_back(split_axis_shape);
  169. split_axis_reduce_inputs_shape->push_back(split_axis_reduce_shape);
  170. return true;
  171. }
  172. } // namespace
  173. bool IsConv2D(const AnfNodePtr &node) {
  174. return (CheckPrimitiveType(node, prim::kPrimConv2D) || CheckPrimitiveType(node, prim::kPrimConv2DFusion));
  175. }
  176. std::shared_ptr<ops::Conv2DFusion> CopyConvPrim(const std::shared_ptr<ops::Conv2DFusion> &ori_conv_prim) {
  177. MS_CHECK_TRUE_MSG(ori_conv_prim != nullptr, nullptr, "input Conv2DFusion is nullptr");
  178. auto new_prim = std::make_shared<ops::Conv2DFusion>();
  179. MS_CHECK_TRUE_MSG(new_prim != nullptr, nullptr, "create Conv2DFusion return nullptr");
  180. new_prim->set_pad(ori_conv_prim->get_pad());
  181. new_prim->set_in_channel(ori_conv_prim->get_in_channel());
  182. new_prim->set_out_channel(ori_conv_prim->get_out_channel());
  183. new_prim->set_dilation(ori_conv_prim->get_dilation());
  184. new_prim->set_format(ori_conv_prim->get_format());
  185. new_prim->set_group(ori_conv_prim->get_group());
  186. new_prim->set_kernel_size(ori_conv_prim->get_kernel_size());
  187. if (ori_conv_prim->get_pad_mode() == SAME) {
  188. new_prim->set_pad_mode(PAD);
  189. } else {
  190. new_prim->set_pad_mode(ori_conv_prim->get_pad_mode());
  191. }
  192. new_prim->set_stride(ori_conv_prim->get_stride());
  193. new_prim->set_activation_type(ori_conv_prim->get_activation_type());
  194. new_prim->set_pad_list(ori_conv_prim->get_pad_list());
  195. auto is_depth_value = ori_conv_prim->GetAttr(ops::kIsDepthWise);
  196. if (is_depth_value != nullptr) {
  197. bool is_depth_wise = GetValue<bool>(is_depth_value);
  198. new_prim->AddAttr(ops::kIsDepthWise, MakeValue<bool>(is_depth_wise));
  199. }
  200. return new_prim;
  201. }
// Recomputes split_info's size_splits / extend_top / extend_bottom for a chain
// of conv nodes by propagating the H-axis split backwards through every conv.
// conv_nodes is ordered last-to-first ([convN, ..., conv1] for the chain
// conv1->conv2->...->convN). Only CuttingStragedy::CUT_H is supported; any
// other axis returns false immediately.
bool UpdateSplitInfo(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &conv_nodes, SplitInfo *split_info) {
  MS_CHECK_TRUE_MSG(func_graph != nullptr, false, "input FuncGraphPtr is nullptr");
  MS_CHECK_TRUE_MSG(split_info != nullptr, false, "input SplitInfo is nullptr");
  if (split_info->axis != CuttingStragedy::CUT_H) {
    return false;
  }
  auto splited_axis = split_info->axis;
  // Caffe/ONNX models are laid out NCHW while the axis constant assumes NHWC,
  // so the axis index shifts by one for those frameworks.
  if (split_info->fmk_type == FmkType::kFmkTypeCaffe ||
      split_info->fmk_type == FmkType::kFmkTypeOnnx) {  // NHWC -> NCHW
    splited_axis += 1;
  }
  size_t node_size = conv_nodes.size();
  size_t index_node = 0;
  // Collect {input-shape, output-shape} pairs for every conv in the chain.
  std::vector<std::vector<ShapeVector>> node_in_out_shapes;
  while (index_node < node_size) {
    // [conv3, conv2, conv1] conv1->conv2->conv3
    auto out_node_name = conv_nodes[index_node]->fullname_with_scope();
    auto output_shapes = Spliter::GetInstance()->graph_node_output_shapes()[out_node_name];
    auto input_shapes = Spliter::GetInstance()->graph_node_input_shapes()[out_node_name];
    // 0-> in-shape 1->out-shape
    // only one in and one output
    MS_ASSERT(!input_shapes.empty() && !output_shapes.empty());
    std::vector<ShapeVector> shape_vec = {input_shapes.front(), output_shapes.front()};
    node_in_out_shapes.emplace_back(shape_vec);
    index_node++;
  }
  // Validate every index used below: [0][1] (first node's out-shape) and
  // [node_size - 1][0] (last node's in-shape) along the split axis.
  if (node_in_out_shapes.empty() || node_in_out_shapes.size() < (node_size - 1) || node_in_out_shapes[0].size() <= 1 ||
      node_in_out_shapes[0][1].size() <= static_cast<size_t>(splited_axis) ||
      node_in_out_shapes[node_size - 1].empty() ||
      node_in_out_shapes[node_size - 1][0].size() <= static_cast<size_t>(splited_axis)) {
    MS_LOG(ERROR) << "out of node_in_out_shapes range";
    return false;
  }
  // Axis value of the chain's final output (conv_nodes[0]) and of the chain's
  // original input (conv_nodes[node_size - 1]).
  int64_t splited_axis_value = node_in_out_shapes[0][1][splited_axis];
  int64_t final_split_axis_value = node_in_out_shapes[node_size - 1][0][splited_axis];
  split_info->ori_split_axis_value = final_split_axis_value;
  size_t split_num = split_info->size_splits.size();
  std::vector<int64_t> split_axis_out_shape;
  std::vector<int64_t> split_axis_reduce_out_shape;
  if (!CalSplitOutputShape(splited_axis_value, split_info, &split_axis_out_shape, &split_axis_reduce_out_shape)) {
    return false;
  }
  // Propagate the sliced output extents backwards through the chain; each
  // CalSplitInShape call appends one row, so after the loop row [node_size]
  // holds the slice extents at the chain's original input.
  std::vector<std::vector<int64_t>> split_axis_inputs_shape{split_axis_out_shape};
  std::vector<std::vector<int64_t>> split_axis_reduce_inputs_shape{split_axis_reduce_out_shape};
  index_node = 0;
  // iter node
  while (index_node < node_size) {
    auto conv_cnode = conv_nodes[index_node]->cast<CNodePtr>();
    MS_ASSERT(conv_cnode != nullptr);
    auto ori_conv_prim = GetValueNode<std::shared_ptr<ops::Conv2DFusion>>(conv_cnode->input(kAnfPrimitiveIndex));
    MS_CHECK_TRUE_RET(ori_conv_prim != nullptr, false);
    if (!CalSplitInShape(node_in_out_shapes, split_info, ori_conv_prim, index_node, &split_axis_inputs_shape,
                         &split_axis_reduce_inputs_shape)) {
      MS_LOG(ERROR) << "CalSplitInShape failed";
      return false;
    }
    index_node++;
  }
  // Rewrite the split ratio plus the overlap (extend) margins from the
  // propagated input extents.
  split_info->size_splits.clear();
  split_info->extend_top.clear();
  split_info->extend_bottom.clear();
  int64_t top = 0;
  // NOTE(review): `bottom` is int32_t while `top` is int64_t, though both feed
  // the same int64 containers — consider unifying the types.
  int32_t bottom = 0;
  split_info->size_splits.push_back(split_axis_inputs_shape[node_size][0]);
  split_info->extend_top.push_back(top);
  split_info->extend_bottom.push_back(bottom);
  for (size_t i = 1; i < split_num; i++) {
    // `begin` is where slice i starts; `top` is the overlap it needs from the
    // previous slice; `value` is the non-overlapping remainder.
    auto begin = split_axis_reduce_inputs_shape[node_size][i] - split_axis_inputs_shape[node_size][i] + 1;
    top = split_axis_reduce_inputs_shape[node_size][i - 1] - begin + 1;
    auto value = split_axis_inputs_shape[node_size][i] - top;
    split_info->size_splits.push_back(value);
    split_info->extend_top.push_back(top);
    split_info->extend_bottom.push_back(bottom);
  }
  return true;
}
  281. bool GetMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_num,
  282. std::vector<AnfNodePtr> *outputs) {
  283. MS_CHECK_TRUE_MSG(func_graph != nullptr, false, "input FuncGraphPtr is nullptr");
  284. MS_CHECK_TRUE_MSG(node != nullptr, false, "input AnfNodePtr is nullptr");
  285. MS_CHECK_TRUE_MSG(outputs != nullptr, false, "input std::vector<AnfNodePtr> is nullptr");
  286. auto cnode = node->cast<CNodePtr>();
  287. MS_CHECK_TRUE_MSG(cnode != nullptr, false, "create CNode return nullptr");
  288. for (size_t i = 0; i < output_num; i++) {
  289. auto index = NewValueNode(SizeToInt(i));
  290. MS_CHECK_TRUE_MSG(index != nullptr, false, "create ValueNode return nullptr");
  291. auto temp = SizeToInt(i);
  292. auto imm = std::make_shared<Int32Imm>(temp);
  293. MS_CHECK_TRUE_MSG(imm != nullptr, false, "create Int32Imm return nullptr");
  294. auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
  295. MS_CHECK_TRUE_MSG(abstract_scalar != nullptr, false, "create AbstractScalar return nullptr");
  296. index->set_abstract(abstract_scalar);
  297. auto tuple_getitem_primitive = NewValueNode(prim::kPrimTupleGetItem);
  298. MS_CHECK_TRUE_MSG(tuple_getitem_primitive != nullptr, false, "create PrimTupleGetItem return nullptr");
  299. auto tuple_getitem = func_graph->NewCNode({tuple_getitem_primitive, node, index});
  300. MS_CHECK_TRUE_MSG(tuple_getitem != nullptr, false, "create CNode return nullptr");
  301. tuple_getitem->set_fullname_with_scope(cnode->fullname_with_scope() + "_TupleGetItem_" + std::to_string(i + 1));
  302. outputs->push_back(tuple_getitem);
  303. }
  304. return true;
  305. }
  306. AnfNodePtr CreateOutputsOfConcat(const FuncGraphPtr &func_graph, const AnfNodePtr &conv_cnode,
  307. const std::vector<AnfNodePtr> &conv_outputs, SplitInfo *split_info,
  308. const std::string &node_name) {
  309. MS_CHECK_TRUE_MSG(func_graph != nullptr, nullptr, "input FuncGraphPtr is nullptr");
  310. MS_CHECK_TRUE_MSG(conv_cnode != nullptr, nullptr, "input AnfNodePtr is nullptr");
  311. MS_CHECK_TRUE_MSG(split_info != nullptr, nullptr, "input SplitInfo is nullptr");
  312. auto nodes_num = static_cast<int64_t>(conv_outputs.size());
  313. if (nodes_num != split_info->out_num) {
  314. MS_LOG(ERROR) << "Conv outputs has wrong input size";
  315. return nullptr;
  316. }
  317. auto concat_prim = std::make_shared<ops::Concat>();
  318. MS_CHECK_TRUE_MSG(concat_prim != nullptr, nullptr, "create ops::Concat return nullptr");
  319. concat_prim->set_axis(split_info->axis);
  320. // the inputs of concate are from the outputs of conv
  321. auto concate_primitive = NewValueNode(concat_prim);
  322. MS_CHECK_TRUE_MSG(concate_primitive != nullptr, nullptr, "create concate_primitive return nullptr");
  323. std::vector<AnfNodePtr> concate_inputs = {concate_primitive};
  324. for (size_t i = 0; i < static_cast<size_t>(nodes_num); i++) {
  325. concate_inputs.push_back(conv_outputs[i]);
  326. }
  327. auto concate_cnode = func_graph->NewCNode(concate_inputs);
  328. MS_CHECK_TRUE_MSG(concate_cnode != nullptr, nullptr, "create concate_cnode return nullptr");
  329. concate_cnode->set_fullname_with_scope(node_name + "_Concat");
  330. concate_cnode->set_scope(conv_cnode->scope());
  331. std::vector<AnfNodePtr> outputs;
  332. if (!GetMultipleOutputsOfAnfNode(func_graph, concate_cnode, 1, &outputs)) {
  333. MS_LOG(ERROR) << "GetMultipleOutputsOfAnfNode failed";
  334. return nullptr;
  335. }
  336. return concate_cnode;
  337. }
  338. bool CreateOutputsOfSplitWithOverlap(const FuncGraphPtr &func_graph, const AnfNodePtr &conv_node,
  339. std::vector<AnfNodePtr> *split_outputs, SplitInfo *split_info,
  340. const std::string &node_name) {
  341. MS_CHECK_TRUE_MSG(func_graph != nullptr, false, "input FuncGraphPtr is nullptr");
  342. MS_CHECK_TRUE_MSG(conv_node != nullptr, false, "input conv_node is nullptr");
  343. MS_CHECK_TRUE_MSG(split_outputs != nullptr, false, "input split_outputs is nullptr");
  344. MS_CHECK_TRUE_MSG(split_info != nullptr, false, "input split_info is nullptr");
  345. // attr of split
  346. auto split_prim = std::make_shared<ops::SplitWithOverlap>();
  347. MS_CHECK_TRUE_MSG(split_prim != nullptr, false, "create ops::SplitWithOverlap return nullptr");
  348. split_prim->set_split_dim(split_info->axis);
  349. split_prim->set_number_split(split_info->out_num);
  350. split_prim->set_ratio(split_info->size_splits);
  351. split_prim->set_extend_top(split_info->extend_top);
  352. split_prim->set_extend_bottom(split_info->extend_bottom);
  353. auto conv_cnode = conv_node->cast<CNodePtr>();
  354. // the inputs of split is from the inputs of conv
  355. auto split_primitive = NewValueNode(split_prim);
  356. MS_CHECK_TRUE_MSG(split_primitive != nullptr, false, "create split_primitive return nullptr");
  357. std::vector<AnfNodePtr> split_inputs = {split_primitive};
  358. // this conv only has one input, which has been ensured before
  359. split_inputs.push_back(conv_cnode->input(1));
  360. auto split_cnode = func_graph->NewCNode(split_inputs);
  361. MS_CHECK_TRUE_MSG(split_cnode != nullptr, false, "create split_cnode return nullptr");
  362. split_cnode->set_fullname_with_scope(node_name + "_Split");
  363. if (split_info->out_num < 0) {
  364. MS_LOG(ERROR) << "out_num should greater then zero";
  365. return false;
  366. }
  367. // create outputs op split
  368. if (!GetMultipleOutputsOfAnfNode(func_graph, split_cnode, split_info->out_num, split_outputs)) {
  369. MS_LOG(ERROR) << "GetMultipleOutputsOfAnfNode failed";
  370. return false;
  371. }
  372. AbstractBasePtrList ptr_list;
  373. for (int64_t i = 0; i < split_info->out_num; i++) {
  374. // set date_type same with weight
  375. auto type_id = static_cast<TypeId>(kNumberTypeFloat32);
  376. auto type_ptr = TypeIdToType(type_id);
  377. std::vector<int64_t> shape_vector;
  378. auto value_node = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
  379. MS_CHECK_TRUE_MSG(value_node != nullptr, false, "create abstract::AbstractTensor return nullptr");
  380. ptr_list.push_back(value_node);
  381. }
  382. split_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(ptr_list));
  383. return true;
  384. }
  385. bool UpdateRatioWithPadStride(int64_t *ratio, size_t ratio_len, size_t split_size, int split_dim_size) {
  386. MS_CHECK_TRUE_MSG(ratio != nullptr, false, "input ratio is nullptr");
  387. int64_t total_block_count = 0;
  388. for (size_t i = 0; i < split_size; i++) {
  389. total_block_count += ratio[i];
  390. }
  391. if (ratio_len < split_size) {
  392. MS_LOG(ERROR) << "out of ratio range";
  393. return false;
  394. }
  395. if (total_block_count < 0) {
  396. MS_LOG(ERROR) << "divide by zero";
  397. return false;
  398. }
  399. std::vector<int64_t> new_ratio(split_size);
  400. int64_t visited_block = 0;
  401. for (size_t i = 0; i < split_size - 1; i++) {
  402. visited_block += ratio[i];
  403. if (INT_MUL_OVERFLOW_THRESHOLD(split_dim_size, visited_block, INT64_MAX)) {
  404. MS_LOG(ERROR) << "int mul overflow";
  405. return false;
  406. }
  407. int64_t cur_border = UP_DIV(split_dim_size * visited_block, total_block_count);
  408. new_ratio[i + 1] = cur_border;
  409. }
  410. for (size_t i = 0; i < split_size; i++) {
  411. ratio[i] = new_ratio[i];
  412. }
  413. return true;
  414. }
  415. } // namespace opt
  416. } // namespace mindspore