You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

strided_slice_kernel.cc 10 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "host_kernels/strided_slice_kernel.h"
  17. #include <memory>
  18. #include "common/fp16_t.h"
  19. #include "common/ge_inner_error_codes.h"
  20. #include "common/math/math_util.h"
  21. #include "common/op/ge_op_utils.h"
  22. #include "framework/common/debug/ge_log.h"
  23. #include "host_kernels/kernel_utils.h"
  24. #include "graph/utils/type_utils.h"
  25. #include "inc/kernel_factory.h"
namespace ge {
namespace {
// File-local constants for the StridedSlice constant-folding kernel.
const int32_t kNumOne = 1;
// StridedSlice expects exactly four inputs: x, begin, end, strides.
const size_t kStridedSliceInputSize = 4;
const size_t kStridedSliceInputIndex0 = 0;  // x (the data tensor)
const size_t kStridedSliceInputIndex1 = 1;  // begin indices
const size_t kStridedSliceInputIndex2 = 2;  // end indices
const size_t kStridedSliceInputIndex3 = 3;  // strides
// Stride substituted when an input stride of 0 is encountered.
// NOTE(review): name has a typo ("Sride"); renaming would touch its use site
// in Compute(), so it is left as-is here.
const int32_t kDefaultSrideSize = 1;
}  // namespace
  36. Status StridedSliceKernel::CheckAndGetAttr(const OpDescPtr &attr, const std::vector<ConstGeTensorPtr> &input,
  37. Attr &args) {
  38. int64_t begin_mask = 0;
  39. int64_t end_mask = 0;
  40. int64_t ellipsis_mask = 0;
  41. int64_t new_axis_mask = 0;
  42. int64_t shrink_axis_mask = 0;
  43. if (attr == nullptr) {
  44. GELOGE(PARAM_INVALID, "input opdescptr is nullptr.");
  45. return PARAM_INVALID;
  46. }
  47. if (input.size() != kStridedSliceInputSize) {
  48. GELOGE(PARAM_INVALID, "The number of input for strided slice must be %zu.", kStridedSliceInputSize);
  49. return PARAM_INVALID;
  50. }
  51. if (!AttrUtils::GetInt(attr, STRIDE_SLICE_ATTR_BEGIN_MASK, begin_mask)) {
  52. GELOGE(PARAM_INVALID, "get begin_mask attr failed.");
  53. return PARAM_INVALID;
  54. }
  55. if (!AttrUtils::GetInt(attr, STRIDE_SLICE_ATTR_END_MASK, end_mask)) {
  56. GELOGE(PARAM_INVALID, "get end_mask attr failed.");
  57. return PARAM_INVALID;
  58. }
  59. if (!AttrUtils::GetInt(attr, STRIDE_SLICE_ATTR_ELLIPSIS_MASK, ellipsis_mask)) {
  60. GELOGE(PARAM_INVALID, "get ellipsis_mask attr failed.");
  61. return PARAM_INVALID;
  62. }
  63. if (!AttrUtils::GetInt(attr, STRIDE_SLICE_ATTR_NEW_AXIS_MASK, new_axis_mask)) {
  64. GELOGE(PARAM_INVALID, "get new_axis_mask attr failed.");
  65. return PARAM_INVALID;
  66. }
  67. if (!AttrUtils::GetInt(attr, STRIDE_SLICE_ATTR_SHRINK_AXIS_MASK, shrink_axis_mask)) {
  68. GELOGE(PARAM_INVALID, "get shrink_axis_mask attr failed.");
  69. return PARAM_INVALID;
  70. }
  71. if ((ellipsis_mask != 0) || (new_axis_mask != 0)) {
  72. GELOGW("ellipsis_mask or new_axis_mask must be 0 with optimizer.");
  73. return NOT_CHANGED;
  74. }
  75. const auto &input_desc = attr->MutableInputDesc(kStridedSliceInputIndex0);
  76. GE_CHECK_NOTNULL(input_desc);
  77. DataType data_type = input_desc->GetDataType();
  78. if ((data_type != DT_FLOAT) && (data_type != DT_INT32)) {
  79. GELOGW(
  80. "Data type of StridedSlice OP must be float or int32."
  81. "Constant folding will not be carried out in this condition"
  82. "which might affect the time performance but not the accuracy");
  83. }
  84. args.begin_mask = begin_mask;
  85. args.end_mask = end_mask;
  86. args.ellipsis_mask = ellipsis_mask;
  87. args.new_axis_mask = new_axis_mask;
  88. args.data_type = static_cast<int64_t>(data_type);
  89. args.shrink_axis_mask = shrink_axis_mask;
  90. ConstGeTensorPtr weight0 = input[kStridedSliceInputIndex0];
  91. ConstGeTensorPtr weight1 = input[kStridedSliceInputIndex1];
  92. ConstGeTensorPtr weight2 = input[kStridedSliceInputIndex2];
  93. ConstGeTensorPtr weight3 = input[kStridedSliceInputIndex3];
  94. if (CheckWeight(weight0, weight1, weight2, weight3) != SUCCESS) {
  95. GELOGE(PARAM_INVALID, "Check And Get Attr failed.");
  96. return PARAM_INVALID;
  97. }
  98. return SUCCESS;
  99. }
  100. Status StridedSliceKernel::CheckWeight(const ConstGeTensorPtr &weight0, const ConstGeTensorPtr &weight1,
  101. const ConstGeTensorPtr &weight2, const ConstGeTensorPtr &weight3) const {
  102. if ((weight0 == nullptr) || (weight1 == nullptr) || (weight2 == nullptr) || (weight3 == nullptr)) {
  103. GELOGW("weight is nullptr.");
  104. return PARAM_INVALID;
  105. }
  106. if (!(weight1->GetTensorDesc().GetDataType() == DT_INT32 && weight2->GetTensorDesc().GetDataType() == DT_INT32 &&
  107. weight3->GetTensorDesc().GetDataType() == DT_INT32)) {
  108. GELOGE(INTERNAL_ERROR, "Data type of StridedSlice OP(begin,end,strides) must be int32.");
  109. return INTERNAL_ERROR;
  110. }
  111. // check data
  112. size_t weight0_size = weight0->GetData().size() / sizeof(int32_t);
  113. size_t weight1_size = weight1->GetData().size() / sizeof(int32_t);
  114. size_t weight2_size = weight2->GetData().size() / sizeof(int32_t);
  115. size_t weight3_size = weight3->GetData().size() / sizeof(int32_t);
  116. if ((weight0_size == 0) || (weight1_size == 0) || (weight2_size == 0) || (weight3_size == 0)) {
  117. GELOGW("Data size of inputs is 0.");
  118. return PARAM_INVALID;
  119. }
  120. // check dim size
  121. size_t weight0_dim_size = weight0->GetTensorDesc().GetShape().GetDimNum();
  122. if (!((weight0_dim_size >= weight1_size) && (weight1_size == weight2_size) && (weight1_size == weight3_size))) {
  123. GELOGW("The sizes of begin, end and stride is not supported.");
  124. return NOT_CHANGED;
  125. }
  126. return SUCCESS;
  127. }
  128. Status StridedSliceKernel::MaskCal(const bool &begin_mask_flag, const bool &end_mask_flag, const bool &shrink_mask_flag,
  129. int32_t &begin_i, int32_t &end_i, int32_t &dim_i) const {
  130. if (shrink_mask_flag) {
  131. begin_i = (begin_i < 0 ? (dim_i + begin_i) : begin_i);
  132. FMK_INT32_ADDCHECK(begin_i, kNumOne);
  133. end_i = begin_i + kNumOne;
  134. } else {
  135. if (begin_mask_flag) {
  136. begin_i = 0;
  137. } else {
  138. begin_i = (begin_i < 0 ? (dim_i + begin_i) : begin_i);
  139. }
  140. if (end_mask_flag) {
  141. end_i = dim_i;
  142. } else {
  143. end_i = (end_i < 0 ? (dim_i + end_i) : end_i);
  144. }
  145. }
  146. return SUCCESS;
  147. }
  148. void StridedSliceKernel::GetOutputDims(uint32_t dims_size, const std::vector<int64_t> &output_dims, const Attr &args,
  149. vector<int64_t> &v_dims) {
  150. for (uint32_t k = 0; k < dims_size; k++) {
  151. bool shrink_mask_i = (static_cast<uint32_t>(args.shrink_axis_mask) & (1 << k));
  152. if (shrink_mask_i) {
  153. continue;
  154. }
  155. v_dims.push_back(output_dims[k]);
  156. }
  157. }
// Constant-folds a StridedSlice node whose four inputs are known constants:
// computes the sliced shape and data on the host and emits the result tensor.
// @param attr     op descriptor carrying the mask attributes and output desc
// @param input    weights: [0]=x, [1]=begin, [2]=end, [3]=strides
// @param v_output receives the folded output tensor on success
// @return SUCCESS when folded, NOT_CHANGED when folding is skipped or fails
Status StridedSliceKernel::Compute(const ge::OpDescPtr attr, const std::vector<ge::ConstGeTensorPtr> &input,
                                   vector<ge::GeTensorPtr> &v_output) {
  GELOGI("StridedSliceKernel in.");
  Attr args;
  Status ret = CheckAndGetAttr(attr, input, args);
  if (ret != SUCCESS) {
    GELOGW("Check And Get Attr failed.");
    return NOT_CHANGED;
  }
  // CheckAndGetAttr has verified input.size() == 4 and that no weight is null.
  ConstGeTensorPtr weight0 = input[kStridedSliceInputIndex0];  // x
  ConstGeTensorPtr weight1 = input[kStridedSliceInputIndex1];  // begin
  ConstGeTensorPtr weight2 = input[kStridedSliceInputIndex2];  // end
  ConstGeTensorPtr weight3 = input[kStridedSliceInputIndex3];  // strides
  const GeShape x_shape = weight0->GetTensorDesc().GetShape();
  size_t dim_size = x_shape.GetDimNum();
  // NOTE(review): element count assumes 4-byte elements.  CheckAndGetAttr only
  // warns (does not reject) dtypes other than float/int32, for which this
  // count would be wrong — confirm non-4-byte dtypes never reach this point.
  size_t data_size = weight0->GetData().size() / sizeof(int32_t);
  const int32_t *begin = reinterpret_cast<const int32_t *>(weight1->GetData().data());
  const int32_t *end = reinterpret_cast<const int32_t *>(weight2->GetData().data());
  const int32_t *stride = reinterpret_cast<const int32_t *>(weight3->GetData().data());
  if ((begin == nullptr) || (end == nullptr) || (stride == nullptr)) {
    GELOGE(PARAM_INVALID, "input weight tensor is nullptr.");
    return NOT_CHANGED;
  }
  std::vector<int64_t> input_dims;   // full shape of x, axis by axis
  std::vector<int64_t> begin_vec;    // normalized start index per axis
  std::vector<int64_t> output_dims;  // sliced extent per axis (pre-shrink)
  std::vector<int64_t> stride_vec;   // normalized (positive) stride per axis
  int64_t dim_final;
  for (size_t i = 0; i < dim_size; i++) {
    int32_t begin_i = begin[i];
    int32_t end_i = end[i];
    int32_t stride_i = stride[i];
    int32_t dim_i = static_cast<int32_t>(x_shape.GetDim(i));
    GELOGI("%d\t %d\t %d\t %d", begin_i, end_i, stride_i, dim_i);
    uint32_t i_temp = static_cast<uint32_t>(i);
    // Per-axis bits of the masks select special handling for this dimension.
    bool begin_mask_i = (static_cast<uint32_t>(args.begin_mask) & (1 << i_temp));
    bool end_mask_i = (static_cast<uint32_t>(args.end_mask) & (1 << i_temp));
    bool shrink_mask_i = (static_cast<uint32_t>(args.shrink_axis_mask) & (1 << i_temp));
    // Resolve negative indices and mask bits into concrete begin_i/end_i.
    ret = MaskCal(begin_mask_i, end_mask_i, shrink_mask_i, begin_i, end_i, dim_i);
    if (ret != SUCCESS) {
      GELOGW("MaskCal failed, because of data overflow.");
      return NOT_CHANGED;
    }
    if (stride_i == 0) {
      // A stride of 0 is silently replaced with 1 rather than rejected.
      stride_i = kDefaultSrideSize;
    } else if (stride_i < 0) {
      // Negative stride: walk backwards by mirroring begin/end around the axis.
      stride_i = -stride_i;
      begin_i = x_shape.GetDim(i) - begin_i - 1;
      end_i = x_shape.GetDim(i) - end_i - 1;
    }
    if ((begin_i == 0) && (end_i == 0)) {
      // begin == end == 0 is treated as "keep the whole axis".
      dim_final = x_shape.GetDim(i);
    } else {
      // NOTE(review): truncating division — begin=0,end=5,stride=2 yields 2,
      // whereas ceiling division would yield 3.  Confirm this matches what
      // OpUtils::SetOutputSliceData actually copies.
      dim_final = abs(end_i - begin_i) / stride_i;
    }
    output_dims.push_back(dim_final);
    input_dims.push_back(x_shape.GetDim(i));
    begin_vec.push_back(begin_i);
    stride_vec.push_back(stride_i);
  }
  // Index 0 always gets a GeTensorDesc object from any OpDescPtr.
  auto output_tensor_desc = attr->GetOutputDesc(0);
  GeTensorPtr output_ptr = MakeShared<GeTensor>(output_tensor_desc);
  if (output_ptr == nullptr) {
    GELOGE(MEMALLOC_FAILED, "MakeShared GeTensor failed, node name %s.", attr->GetName().c_str());
    return NOT_CHANGED;
  }
  void *data = reinterpret_cast<void *>(const_cast<uint8_t *>(weight0->GetData().data()));
  GE_CHECK_NOTNULL(data);
  // Copy the selected elements of x into output_ptr's buffer.
  ret = OpUtils::SetOutputSliceData(data, static_cast<int64_t>(data_size), args.data_type, input_dims, begin_vec,
                                    output_dims, output_ptr.get(), stride_vec);
  if (ret != SUCCESS) {
    GELOGE(INTERNAL_ERROR, "SetOutputSliceData failed.");
    return NOT_CHANGED;
  }
  GeTensorDesc &t_d = output_ptr->MutableTensorDesc();
  t_d.SetDataType(static_cast<DataType>(args.data_type));
  // Drop shrunken axes from the final shape before publishing the tensor.
  uint32_t final_dim_size = static_cast<uint32_t>(output_dims.size());
  vector<int64_t> v_dims;
  GetOutputDims(final_dim_size, output_dims, args, v_dims);
  t_d.SetShape(GeShape(v_dims));
  v_output.push_back(output_ptr);
  GELOGI("StridedSliceKernel success.");
  return SUCCESS;
}
// Register this kernel with the factory so STRIDEDSLICE nodes dispatch here.
REGISTER_KERNEL(STRIDEDSLICE, StridedSliceKernel);
  244. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。