/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "host_kernels/strided_slice_kernel.h"

#include <algorithm>
#include <cstdint>
#include <limits>
#include <memory>
#include <vector>

#include "common/fp16_t.h"
#include "common/ge_inner_error_codes.h"
#include "common/math/math_util.h"
#include "common/op/ge_op_utils.h"
#include "framework/common/debug/ge_log.h"
#include "host_kernels/kernel_utils.h"
#include "graph/utils/type_utils.h"
#include "inc/kernel_factory.h"

namespace ge {
namespace {
const int32_t kNumOne = 1;
const size_t kStridedSliceInputSize = 4;
const size_t kStridedSliceInputIndex0 = 0;
const size_t kStridedSliceInputIndex1 = 1;
const size_t kStridedSliceInputIndex2 = 2;
const size_t kStridedSliceInputIndex3 = 3;
const int32_t kDefaultStrideSize = 1;
// Width of the 64-bit mask attributes; axes at or beyond this index can never be flagged.
const size_t kMaskBitWidth = 64;

// Returns true when bit `axis` of `mask` is set.
// Shifting a 64-bit unsigned one avoids the undefined behaviour of `1 << axis` on an int
// once axis reaches 31.
bool IsMaskBitSet(int64_t mask, size_t axis) {
  if (axis >= kMaskBitWidth) {
    return false;
  }
  return ((static_cast<uint64_t>(mask) >> axis) & 1ULL) != 0;
}

// Clamps a slice bound into [0, dim], the valid bound range for an axis of length dim.
int64_t ClampToAxis(int64_t value, int64_t dim) {
  return std::min(std::max(value, static_cast<int64_t>(0)), dim);
}

// Number of elements visited by the forward slice [begin, end) with a positive stride,
// i.e. ceil((end - begin) / stride), or 0 when the range is empty.
int64_t ForwardSliceLength(int64_t begin, int64_t end, int64_t stride) {
  if (end <= begin) {
    return 0;
  }
  return (end - begin + stride - 1) / stride;
}
}  // namespace

// Reads the five StridedSlice mask attributes from `attr` into `args` and validates the inputs.
// Returns:
//   PARAM_INVALID - attr is null, the input count is not 4, a mask attribute is missing,
//                   or CheckWeight rejects the input tensors;
//   NOT_CHANGED   - ellipsis_mask/new_axis_mask are non-zero, or x is not float/int32
//                   (these cases are not folded here);
//   SUCCESS       - args is fully populated.
Status StridedSliceKernel::CheckAndGetAttr(const OpDescPtr &attr, const std::vector<ConstGeTensorPtr> &input,
                                           Attr &args) {
  int64_t begin_mask = 0;
  int64_t end_mask = 0;
  int64_t ellipsis_mask = 0;
  int64_t new_axis_mask = 0;
  int64_t shrink_axis_mask = 0;

  if (attr == nullptr) {
    GELOGW("input opdescptr is nullptr.");
    return PARAM_INVALID;
  }
  if (input.size() != kStridedSliceInputSize) {
    GELOGW("The number of input for strided slice must be %zu.", kStridedSliceInputSize);
    return PARAM_INVALID;
  }
  if (!AttrUtils::GetInt(attr, STRIDE_SLICE_ATTR_BEGIN_MASK, begin_mask)) {
    GELOGW("get begin_mask attr failed.");
    return PARAM_INVALID;
  }
  if (!AttrUtils::GetInt(attr, STRIDE_SLICE_ATTR_END_MASK, end_mask)) {
    GELOGW("get end_mask attr failed.");
    return PARAM_INVALID;
  }
  if (!AttrUtils::GetInt(attr, STRIDE_SLICE_ATTR_ELLIPSIS_MASK, ellipsis_mask)) {
    GELOGW("get ellipsis_mask attr failed.");
    return PARAM_INVALID;
  }
  if (!AttrUtils::GetInt(attr, STRIDE_SLICE_ATTR_NEW_AXIS_MASK, new_axis_mask)) {
    GELOGW("get new_axis_mask attr failed.");
    return PARAM_INVALID;
  }
  if (!AttrUtils::GetInt(attr, STRIDE_SLICE_ATTR_SHRINK_AXIS_MASK, shrink_axis_mask)) {
    GELOGW("get shrink_axis_mask attr failed.");
    return PARAM_INVALID;
  }
  // Ellipsis and new-axis expansion are not implemented by this kernel.
  if ((ellipsis_mask != 0) || (new_axis_mask != 0)) {
    GELOGW("ellipsis_mask or new_axis_mask must be 0 with optimizer.");
    return NOT_CHANGED;
  }

  const auto &input_desc = attr->MutableInputDesc(kStridedSliceInputIndex0);
  GE_CHECK_NOTNULL(input_desc);
  DataType data_type = input_desc->GetDataType();
  // Compute() sizes the input buffer with sizeof(int32_t), so only 4-byte float/int32 data
  // can be folded safely. Bail out instead of continuing after the warning.
  if ((data_type != DT_FLOAT) && (data_type != DT_INT32)) {
    GELOGW(
        "Data type of StridedSlice OP must be float or int32. "
        "Constant folding will not be carried out in this condition, "
        "which might affect the time performance but not the accuracy.");
    return NOT_CHANGED;
  }

  args.begin_mask = begin_mask;
  args.end_mask = end_mask;
  args.ellipsis_mask = ellipsis_mask;
  args.new_axis_mask = new_axis_mask;
  args.data_type = static_cast<int64_t>(data_type);
  args.shrink_axis_mask = shrink_axis_mask;

  ConstGeTensorPtr weight0 = input[kStridedSliceInputIndex0];
  ConstGeTensorPtr weight1 = input[kStridedSliceInputIndex1];
  ConstGeTensorPtr weight2 = input[kStridedSliceInputIndex2];
  ConstGeTensorPtr weight3 = input[kStridedSliceInputIndex3];
  if (CheckWeight(weight0, weight1, weight2, weight3) != SUCCESS) {
    GELOGW("Check And Get Attr failed.");
    return PARAM_INVALID;
  }

  return SUCCESS;
}

// Validates the four StridedSlice inputs: x (weight0), begin (weight1), end (weight2), strides (weight3).
// Requires:
//   - all four tensors are non-null;
//   - begin/end/strides are DT_INT32;
//   - every tensor carries data;
//   - begin/end/strides have equal length, not exceeding the rank of x.
// Returns PARAM_INVALID, INTERNAL_ERROR (wrong index dtype) or NOT_CHANGED (unsupported lengths) on failure.
Status StridedSliceKernel::CheckWeight(const ConstGeTensorPtr &weight0, const ConstGeTensorPtr &weight1,
                                       const ConstGeTensorPtr &weight2, const ConstGeTensorPtr &weight3) const {
  if ((weight0 == nullptr) || (weight1 == nullptr) || (weight2 == nullptr) || (weight3 == nullptr)) {
    GELOGW("weight is nullptr.");
    return PARAM_INVALID;
  }
  if (!(weight1->GetTensorDesc().GetDataType() == DT_INT32 && weight2->GetTensorDesc().GetDataType() == DT_INT32 &&
        weight3->GetTensorDesc().GetDataType() == DT_INT32)) {
    GELOGE(INTERNAL_ERROR, "Data type of StridedSlice OP(begin,end,strides) must be int32.");
    return INTERNAL_ERROR;
  }

  // check data
  size_t weight0_size = weight0->GetData().size() / sizeof(int32_t);
  size_t weight1_size = weight1->GetData().size() / sizeof(int32_t);
  size_t weight2_size = weight2->GetData().size() / sizeof(int32_t);
  size_t weight3_size = weight3->GetData().size() / sizeof(int32_t);
  if ((weight0_size == 0) || (weight1_size == 0) || (weight2_size == 0) || (weight3_size == 0)) {
    GELOGW("Data size of inputs is 0.");
    return PARAM_INVALID;
  }

  // check dim size
  size_t weight0_dim_size = weight0->GetTensorDesc().GetShape().GetDimNum();
  if (!((weight0_dim_size >= weight1_size) && (weight1_size == weight2_size) && (weight1_size == weight3_size))) {
    GELOGW("The sizes of begin, end and stride is not supported.");
    return NOT_CHANGED;
  }

  return SUCCESS;
}

// Resolves begin_i/end_i for one axis of length dim_i according to the mask flags.
// - Shrink axis: a negative begin is wrapped by dim_i and the range becomes [begin, begin + 1).
//   FMK_INT32_ADDCHECK (common/math/math_util.h) guards that +1 against int32 overflow.
// - Otherwise: begin is forced to 0 when its mask bit is set, and end to dim_i when its mask bit is set.
//   Unmasked negative bounds are wrapped by adding dim_i.
// Results are NOT clamped to [0, dim_i]; the caller does that.
Status StridedSliceKernel::MaskCal(const bool &begin_mask_flag, const bool &end_mask_flag, const bool &shrink_mask_flag,
                                   int32_t &begin_i, int32_t &end_i, int32_t &dim_i) const {
  if (shrink_mask_flag) {
    begin_i = (begin_i < 0 ? (dim_i + begin_i) : begin_i);
    FMK_INT32_ADDCHECK(begin_i, kNumOne);
    end_i = begin_i + kNumOne;
  } else {
    if (begin_mask_flag) {
      begin_i = 0;
    } else {
      begin_i = (begin_i < 0 ? (dim_i + begin_i) : begin_i);
    }
    if (end_mask_flag) {
      end_i = dim_i;
    } else {
      end_i = (end_i < 0 ? (dim_i + end_i) : end_i);
    }
  }
  return SUCCESS;
}

// Appends to v_dims the first dims_size entries of output_dims whose shrink_axis_mask bit is clear.
// Shrunk axes are dropped from the final output shape.
void StridedSliceKernel::GetOutputDims(uint32_t dims_size, const std::vector<int64_t> &output_dims, const Attr &args,
                                       vector<int64_t> &v_dims) {
  for (uint32_t k = 0; k < dims_size; k++) {
    if (IsMaskBitSet(args.shrink_axis_mask, k)) {
      continue;
    }
    v_dims.push_back(output_dims[k]);
  }
}

// Returns SUCCESS when at least one output dim is positive.
// Returns NOT_CHANGED (after logging the op name) when every dim is <= 0.
Status StridedSliceKernel::CheckOutputDims(const std::vector<int64_t> &output_dims, const OpDescPtr attr) {
  // check dim not all less than 0
  for (auto dim : output_dims) {
    if (dim > 0) {
      return SUCCESS;
    }
  }
  GELOGW("all output dim <=0, can't be processed. op_name : %s", attr->GetName().c_str());
  return NOT_CHANGED;
}

// Constant-folds StridedSlice(x, begin, end, strides) into a single output tensor appended to v_output.
// Any condition this kernel cannot fold correctly yields NOT_CHANGED, so the op is left in the graph:
// - unsupported attributes or inputs;
// - unknown dims;
// - negative strides;
// - an out-of-range shrink index.
Status StridedSliceKernel::Compute(const ge::OpDescPtr attr, const std::vector<ge::ConstGeTensorPtr> &input,
                                   vector<ge::GeTensorPtr> &v_output) {
  GELOGI("StridedSliceKernel in.");
  Attr args;
  Status ret = CheckAndGetAttr(attr, input, args);
  if (ret != SUCCESS) {
    GELOGW("Check And Get Attr failed.");
    return NOT_CHANGED;
  }

  ConstGeTensorPtr weight0 = input[kStridedSliceInputIndex0];
  ConstGeTensorPtr weight1 = input[kStridedSliceInputIndex1];
  ConstGeTensorPtr weight2 = input[kStridedSliceInputIndex2];
  ConstGeTensorPtr weight3 = input[kStridedSliceInputIndex3];

  const GeShape x_shape = weight0->GetTensorDesc().GetShape();
  size_t dim_size = x_shape.GetDimNum();
  size_t data_size = weight0->GetData().size() / sizeof(int32_t);
  // Number of begin/end/strides entries. CheckWeight verified the three lengths are equal and
  // do not exceed the rank of x. Axes at or past this index have no slice parameters, so the
  // buffers below must not be indexed for them.
  size_t slice_param_size = weight1->GetData().size() / sizeof(int32_t);

  const int32_t *begin = reinterpret_cast<const int32_t *>(weight1->GetData().data());
  const int32_t *end = reinterpret_cast<const int32_t *>(weight2->GetData().data());
  const int32_t *stride = reinterpret_cast<const int32_t *>(weight3->GetData().data());
  if ((begin == nullptr) || (end == nullptr) || (stride == nullptr)) {
    GELOGW("input weight tensor is nullptr.");
    return NOT_CHANGED;
  }

  std::vector<int64_t> input_dims;
  std::vector<int64_t> begin_vec;
  std::vector<int64_t> output_dims;
  std::vector<int64_t> stride_vec;
  for (size_t i = 0; i < dim_size; i++) {
    const int64_t axis_dim = x_shape.GetDim(i);
    // Unknown (negative) dims cannot be folded, and MaskCal works in int32_t.
    if ((axis_dim < 0) || (axis_dim > std::numeric_limits<int32_t>::max())) {
      GELOGW("Dim %zu of input is %lld, which is not supported for folding.", i, static_cast<long long>(axis_dim));
      return NOT_CHANGED;
    }

    if (i >= slice_param_size) {
      // No begin/end/stride for this trailing axis: take it whole, as TensorFlow's strided_slice does.
      if (IsMaskBitSet(args.shrink_axis_mask, i)) {
        GELOGW("shrink_axis_mask flags dim %zu, which has no begin/end/stride value.", i);
        return NOT_CHANGED;
      }
      output_dims.push_back(axis_dim);
      input_dims.push_back(axis_dim);
      begin_vec.push_back(0);
      stride_vec.push_back(kDefaultStrideSize);
      continue;
    }

    int32_t begin_i = begin[i];
    int32_t end_i = end[i];
    int32_t stride_i = stride[i];
    int32_t dim_i = static_cast<int32_t>(axis_dim);
    GELOGI("%d\t %d\t %d\t %d", begin_i, end_i, stride_i, dim_i);
    bool begin_mask_i = IsMaskBitSet(args.begin_mask, i);
    bool end_mask_i = IsMaskBitSet(args.end_mask, i);
    bool shrink_mask_i = IsMaskBitSet(args.shrink_axis_mask, i);
    ret = MaskCal(begin_mask_i, end_mask_i, shrink_mask_i, begin_i, end_i, dim_i);
    if (ret != SUCCESS) {
      GELOGW("MaskCal failed, because of data overflow.");
      return NOT_CHANGED;
    }

    if (stride_i == 0) {
      stride_i = kDefaultStrideSize;
    } else if (stride_i < 0) {
      // SetOutputSliceData only walks forward from begin. Mirroring begin/end would copy the
      // wrong elements in the wrong order, so reversed slices are left to the runtime op.
      GELOGW("Negative stride %d on dim %zu is not supported for constant folding.", stride_i, i);
      return NOT_CHANGED;
    }

    // A shrunk axis selects exactly one element, which must exist.
    if (shrink_mask_i && ((begin_i < 0) || (begin_i >= dim_i))) {
      GELOGW("Shrink index %d is out of range [0, %d) on dim %zu.", begin_i, dim_i, i);
      return NOT_CHANGED;
    }

    // Legacy behaviour kept from the original kernel: begin == end == 0 selects the whole axis.
    // NOTE(review): TensorFlow treats [0:0] as empty; confirm which producers rely on this.
    if ((begin_i == 0) && (end_i == 0)) {
      end_i = dim_i;
    }

    // Clamp so the copy never reads outside the axis, then count ceil((end - begin) / stride).
    const int64_t begin_clamped = ClampToAxis(begin_i, axis_dim);
    const int64_t end_clamped = ClampToAxis(end_i, axis_dim);
    output_dims.push_back(ForwardSliceLength(begin_clamped, end_clamped, stride_i));
    input_dims.push_back(axis_dim);
    begin_vec.push_back(begin_clamped);
    stride_vec.push_back(stride_i);
  }

  // Index 0 always yields a GeTensorDesc object from any OpDescPtr.
  auto output_tensor_desc = attr->GetOutputDesc(0);
  GeTensorPtr output_ptr = MakeShared<GeTensor>(output_tensor_desc);
  if (output_ptr == nullptr) {
    GELOGW("MakeShared GeTensor failed, node name %s.", attr->GetName().c_str());
    return NOT_CHANGED;
  }

  void *data = reinterpret_cast<void *>(const_cast<uint8_t *>(weight0->GetData().data()));
  GE_CHECK_NOTNULL(data);

  ret = CheckOutputDims(output_dims, attr);
  if (ret != SUCCESS) {
    return ret;
  }

  ret = OpUtils::SetOutputSliceData(data, static_cast<int64_t>(data_size), args.data_type, input_dims, begin_vec,
                                    output_dims, output_ptr.get(), stride_vec);
  if (ret != SUCCESS) {
    GELOGW("SetOutputSliceData failed.");
    return NOT_CHANGED;
  }

  GeTensorDesc &t_d = output_ptr->MutableTensorDesc();
  t_d.SetDataType(static_cast<DataType>(args.data_type));

  // Shrunk axes were sliced to length 1 above; drop them from the reported output shape.
  uint32_t final_dim_size = static_cast<uint32_t>(output_dims.size());
  vector<int64_t> v_dims;
  GetOutputDims(final_dim_size, output_dims, args, v_dims);
  t_d.SetShape(GeShape(v_dims));
  v_output.push_back(output_ptr);
  GELOGI("StridedSliceKernel success.");
  return SUCCESS;
}

REGISTER_KERNEL(STRIDEDSLICE, StridedSliceKernel);
}  // namespace ge