From: @zoloft Reviewed-by: Signed-off-by:pull/15760/MERGE
| @@ -0,0 +1,46 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "nnacl/base/split_with_over_lap_base.h" | |||
| #include "nnacl/split_parameter.h" | |||
| #include <string.h> | |||
| #include "nnacl/errorcode.h" | |||
| int DoSplitWithOverlap(char *in_data, char **out_data, int num_split, int split_dim_size, int element_bytes, | |||
| int outer_total_dim, int inner_stride, int *start_indices, int *end_indices) { | |||
| int input_stride = split_dim_size * inner_stride * element_bytes; | |||
| for (int slice_idx = 0; slice_idx < num_split; slice_idx++) { | |||
| int out_stride = (end_indices[slice_idx] - start_indices[slice_idx]) * inner_stride * element_bytes; | |||
| char *src_ptr = in_data + start_indices[slice_idx] * inner_stride * element_bytes; | |||
| for (int out_idx = 0; out_idx < outer_total_dim; out_idx++) { | |||
| (void)(memcpy(out_data[slice_idx] + out_idx * out_stride, src_ptr, out_stride)); | |||
| src_ptr += input_stride; | |||
| } | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| int DoSplitWithOverlapParallel(char *in_data, char **out_data, int slice_idx, int split_dim_size, int element_bytes, | |||
| int outer_total_dim, int inner_stride, int *start_indices, int *end_indices) { | |||
| int input_stride = split_dim_size * inner_stride * element_bytes; | |||
| int out_stride = (end_indices[slice_idx] - start_indices[slice_idx]) * inner_stride * element_bytes; | |||
| char *src_ptr = in_data + start_indices[slice_idx] * inner_stride * element_bytes; | |||
| for (int i = 0; i < outer_total_dim; i++) { | |||
| (void)memcpy(out_data[slice_idx] + i * out_stride, src_ptr, out_stride); | |||
| src_ptr += input_stride; | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| @@ -0,0 +1,36 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_NNACL_NNACL_SPLIT_WITH_OVER_LAP_BASE_H_ | |||
| #define MINDSPORE_NNACL_NNACL_SPLIT_WITH_OVER_LAP_BASE_H_ | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/split_parameter.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| int DoSplitWithOverlap(char *in_data, char **out_data, int num_split, int split_dim_size, int element_bytes, | |||
| int outer_total_dim, int inner_stride, int *start_indices, int *end_indices); | |||
| int DoSplitWithOverlapParallel(char *in_data, char **out_data, int slice_idx, int split_dim_size, int element_bytes, | |||
| int outer_total_dim, int inner_stride, int *start_indices, int *end_indices); | |||
| #ifdef __cplusplus | |||
| } | |||
| #endif | |||
| #endif // MINDSPORE_NNACL_NNACL_SPLIT_WITH_OVER_LAP_BASE_H_ | |||
| @@ -217,8 +217,9 @@ enum PrimType { | |||
| PrimType_Call = 190, | |||
| PrimType_Custom = 191, | |||
| PrimType_CumSum = 192, | |||
| PrimType_SplitWithOverlap = 193, | |||
| PrimType_MIN = PrimType_NONE, | |||
| PrimType_MAX = PrimType_CumSum + 1 | |||
| PrimType_MAX = PrimType_SplitWithOverlap + 1 | |||
| }; | |||
| void RegInfer(int prim_type, InferShape func); | |||
| @@ -0,0 +1,97 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "nnacl/infer/split_with_over_lap_infer.h" | |||
| #include "nnacl/infer/infer_register.h" | |||
| int SplitWithOverlapInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, | |||
| OpParameter *parameter) { | |||
| #ifdef Debug | |||
| int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); | |||
| if (check_ret != NNACL_OK) { | |||
| return check_ret; | |||
| } | |||
| #endif | |||
| if (!parameter->infer_flag_) { | |||
| return NNACL_INFER_INVALID; | |||
| } | |||
| const TensorC *input = inputs[0]; | |||
| if (inputs_size < 1) { | |||
| return NNACL_ERR; | |||
| } | |||
| if (outputs_size == 0) { | |||
| return NNACL_ERR; | |||
| } | |||
| SplitWithOverlapParameter *param = (SplitWithOverlapParameter *)parameter; | |||
| int number_split = param->num_split_; | |||
| if (outputs_size != number_split) { | |||
| return NNACL_ERR; | |||
| } | |||
| int stride = param->stride_; | |||
| int pad_top = param->pad_top_; | |||
| int split_dim = param->split_dim_; | |||
| int ratio[SPLIT_MAX_SLICE_NUM]; | |||
| int extend_top[SPLIT_MAX_SLICE_NUM]; | |||
| int extend_bottom[SPLIT_MAX_SLICE_NUM]; | |||
| for (int i = 0; i < number_split; ++i) { | |||
| ratio[i] = param->ratio_[i]; | |||
| extend_top[i] = param->extend_top_[i]; | |||
| extend_bottom[i] = param->extend_bottom_[i]; | |||
| } | |||
| const int *input_shape = input->shape_; | |||
| int split_dim_size = input_shape[split_dim]; | |||
| int total_block_count = 0; | |||
| for (int i = 0; i < number_split; i++) { | |||
| total_block_count += ratio[i]; | |||
| } | |||
| int borders[MAX_SHAPE_SIZE]; | |||
| borders[0] = 0; | |||
| int visited_block = 0; | |||
| for (int i = 0; i < number_split - 1; i++) { | |||
| visited_block += ratio[i]; | |||
| int cur_border = UP_DIV(split_dim_size * visited_block, total_block_count); | |||
| if (stride != 0) { | |||
| // make sure border align with stride | |||
| cur_border = UP_ROUND(cur_border + pad_top, stride); | |||
| borders[i + 1] = cur_border - pad_top; | |||
| } else { | |||
| borders[i + 1] = cur_border; | |||
| } | |||
| } | |||
| borders[number_split - 1] = split_dim_size; | |||
| for (int i = 0; i < number_split; ++i) { | |||
| int output_shape[MAX_SHAPE_SIZE]; | |||
| for (int dim = 0; dim < input->shape_size_; dim++) { | |||
| if (dim == split_dim) { | |||
| int splited_size = borders[i + 1] - borders[i]; | |||
| splited_size += extend_top[i] + extend_bottom[i]; | |||
| output_shape[dim] = splited_size; | |||
| } else { | |||
| output_shape[dim] = input_shape[dim]; | |||
| } | |||
| } | |||
| SetShapeArray(outputs[i], output_shape, input->shape_size_); | |||
| SetDataTypeFormat(outputs[i], input); | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| REG_INFER(SplitWithOverlap, PrimType_SplitWithOverlap, SplitWithOverlapInferShape) | |||
| @@ -0,0 +1,32 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_NNACL_SPLIT_WITH_OVER_LAP_INFER_H | |||
| #define MINDSPORE_NNACL_SPLIT_WITH_OVER_LAP_INFER_H | |||
| #include "nnacl/infer/common_infer.h" | |||
| #include "nnacl/split_parameter.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| int SplitWithOverlapInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, | |||
| OpParameter *parameter); | |||
| #ifdef __cplusplus | |||
| } | |||
| #endif | |||
| #endif // MINDSPORE_NNACL_SPLIT_WITH_OVER_LAP_INFER_H | |||
| @@ -20,6 +20,7 @@ | |||
| #include "nnacl/op_base.h" | |||
| #define SPLIT_STRIDES_SIZE 32 | |||
| #define SPLIT_MAX_SLICE_NUM 10 | |||
| typedef struct SplitQuantArg { | |||
| QuantArg in_args_; | |||
| @@ -44,4 +45,15 @@ typedef struct SplitParameter { | |||
| int split_count_; | |||
| } SplitParameter; | |||
| typedef struct SplitWithOverlapParameter { | |||
| OpParameter op_parameter_; | |||
| int num_split_; | |||
| int split_dim_; | |||
| int stride_; | |||
| int pad_top_; | |||
| int ratio_[SPLIT_MAX_SLICE_NUM]; | |||
| int extend_top_[SPLIT_MAX_SLICE_NUM]; | |||
| int extend_bottom_[SPLIT_MAX_SLICE_NUM]; | |||
| } SplitWithOverlapParameter; | |||
| #endif // MINDSPORE_NNACL_SPLIT_PARAMETER_H_ | |||
| @@ -234,6 +234,13 @@ constexpr auto kSideEffectIO = "side_effect_io"; | |||
| constexpr auto kDeviceType = "device_type"; | |||
| constexpr auto kExclusive = "exclusive"; | |||
| constexpr auto kReverse = "reverse"; | |||
| constexpr auto kSplitStride = "split_stride"; | |||
| constexpr auto kExtendTop = "extend_top"; | |||
| constexpr auto kExtendBottom = "extend_bottom"; | |||
| constexpr auto kNumberSplit = "number_split"; | |||
| constexpr auto kSplitDim = "split_dim"; | |||
| constexpr auto kPadTop = "pad_top"; | |||
| constexpr auto kTransFormat = "trans_format"; | |||
| const std::set<TypePtr> common_valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, | |||
| kUInt32, kUInt64, kFloat16, kFloat32, kFloat64}; | |||
| @@ -0,0 +1,96 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ops/split_with_overlap.h" | |||
| #include "ops/op_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| void SplitWithOverlap::Init(int64_t number_split, const std::vector<int64_t> &ratio, | |||
| const std::vector<int64_t> &extend_top, const std::vector<int64_t> &extend_bottom, | |||
| int64_t split_dim, int64_t stride, int64_t pad_top, bool trans_format) { | |||
| this->set_number_split(number_split); | |||
| this->set_ratio(ratio); | |||
| this->set_extend_top(extend_top); | |||
| this->set_extend_bottom(extend_bottom); | |||
| this->set_split_dim(split_dim); | |||
| this->set_stride(stride); | |||
| this->set_pad_top(pad_top); | |||
| this->set_trans_format(trans_format); | |||
| } | |||
| void SplitWithOverlap::set_ratio(const std::vector<int64_t> &ratio) { this->AddAttr(kRatio, MakeValue(ratio)); } | |||
| void SplitWithOverlap::set_extend_top(const std::vector<int64_t> &extend_top) { | |||
| this->AddAttr(kExtendTop, MakeValue(extend_top)); | |||
| } | |||
| void SplitWithOverlap::set_extend_bottom(const std::vector<int64_t> &extend_bottom) { | |||
| this->AddAttr(kExtendBottom, MakeValue(extend_bottom)); | |||
| } | |||
| void SplitWithOverlap::set_number_split(int64_t number_split) { this->AddAttr(kNumberSplit, MakeValue(number_split)); } | |||
| void SplitWithOverlap::set_split_dim(int64_t split_dim) { this->AddAttr(kSplitDim, MakeValue(split_dim)); } | |||
| void SplitWithOverlap::set_stride(int64_t stride) { this->AddAttr(kSplitStride, MakeValue(stride)); } | |||
| void SplitWithOverlap::set_pad_top(int64_t pad_top) { this->AddAttr(kPadTop, MakeValue(pad_top)); } | |||
| void SplitWithOverlap::set_trans_format(bool trans_format) { this->AddAttr(kTransFormat, MakeValue(trans_format)); } | |||
| std::vector<int64_t> SplitWithOverlap::get_ratio() const { | |||
| auto value_ptr = GetAttr(kRatio); | |||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||
| } | |||
| std::vector<int64_t> SplitWithOverlap::get_extend_top() const { | |||
| auto value_ptr = GetAttr(kExtendTop); | |||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||
| } | |||
| std::vector<int64_t> SplitWithOverlap::get_extend_bottom() const { | |||
| auto value_ptr = GetAttr(kExtendBottom); | |||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||
| } | |||
| int64_t SplitWithOverlap::get_number_split() const { | |||
| auto value_ptr = GetAttr(kNumberSplit); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| int64_t SplitWithOverlap::get_split_dim() const { | |||
| auto value_ptr = GetAttr(kSplitDim); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| int64_t SplitWithOverlap::get_stride() const { | |||
| auto value_ptr = GetAttr(kSplitStride); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| int64_t SplitWithOverlap::get_pad_top() const { | |||
| auto value_ptr = GetAttr(kPadTop); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| bool SplitWithOverlap::get_trans_format() const { | |||
| auto value_ptr = GetAttr(kTransFormat); | |||
| return GetValue<bool>(value_ptr); | |||
| } | |||
| REGISTER_PRIMITIVE_C(kNameSplitWithOverlap, SplitWithOverlap); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,58 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_OPS_SPLIT_WITH_OVERLAP_H_ | |||
| #define MINDSPORE_CORE_OPS_SPLIT_WITH_OVERLAP_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameSplitWithOverlap = "SplitWithOverlap"; | |||
| class SplitWithOverlap : public PrimitiveC { | |||
| public: | |||
| SplitWithOverlap() : PrimitiveC(kNameSplitWithOverlap) {} | |||
| ~SplitWithOverlap() = default; | |||
| MS_DECLARE_PARENT(SplitWithOverlap, PrimitiveC); | |||
| void Init(int64_t number_split, const std::vector<int64_t> &ratio, const std::vector<int64_t> &extend_top, | |||
| const std::vector<int64_t> &extend_bottom, int64_t split_dim, int64_t stride, int64_t pad_top, | |||
| bool trans_format); | |||
| void set_ratio(const std::vector<int64_t> &ratio); | |||
| void set_extend_top(const std::vector<int64_t> &extend_top); | |||
| void set_extend_bottom(const std::vector<int64_t> &extend_bottom); | |||
| void set_number_split(int64_t number_split); | |||
| void set_split_dim(int64_t split_dim); | |||
| void set_stride(int64_t stride); | |||
| void set_pad_top(int64_t pad_top); | |||
| void set_trans_format(bool trans_format); | |||
| std::vector<int64_t> get_ratio() const; | |||
| std::vector<int64_t> get_extend_top() const; | |||
| std::vector<int64_t> get_extend_bottom() const; | |||
| int64_t get_number_split() const; | |||
| int64_t get_split_dim() const; | |||
| int64_t get_stride() const; | |||
| int64_t get_pad_top() const; | |||
| bool get_trans_format() const; | |||
| }; | |||
| AbstractBasePtr SplitWithOverlapInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimSplitWithOverlap = std::shared_ptr<SplitWithOverlap>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_OPS_SPLIT_WITH_OVERLAP_H_ | |||
| @@ -210,6 +210,7 @@ union PrimitiveType { | |||
| Call, | |||
| Custom, | |||
| CumSum, | |||
| SplitWithOverlap, | |||
| } | |||
| table Abs { | |||
| @@ -1115,3 +1116,14 @@ table Custom { | |||
| type: string; | |||
| attr: [Attribute]; | |||
| } | |||
| table SplitWithOverlap { | |||
| number_split: long; | |||
| ratio: [long]; | |||
| extend_top: [long]; | |||
| extend_bottom: [long]; | |||
| split_dim: long; | |||
| stride: long; | |||
| pad_top: long; | |||
| trans_format: bool = false; | |||
| } | |||
| @@ -209,6 +209,7 @@ OP_TYPE(LogSoftmax) | |||
| OP_TYPE(Call) | |||
| OP_TYPE(Custom) | |||
| OP_TYPE(CumSum) | |||
| OP_TYPE(SplitWithOverlap) | |||
| OP_TYPE_DEF_END(PrimitiveType) | |||
| OP_SCHEMA_DEF(Abs) | |||
| @@ -1114,3 +1115,14 @@ OP_SCHEMA_DEF_ONLY(Custom) | |||
| OP_ATTR_ONLY(type, string) | |||
| OP_ATTR_ONLY(attr, [Attribute]) | |||
| OP_SCHEMA_DEF_ONLY_END(Custom) | |||
| OP_SCHEMA_DEF(SplitWithOverlap) | |||
| OP_ATTR(number_split, long) | |||
| OP_ATTR(ratio, [long]) | |||
| OP_ATTR(extend_top, [long]) | |||
| OP_ATTR(extend_bottom, [long]) | |||
| OP_ATTR(split_dim, long) | |||
| OP_ATTR(stride, long) | |||
| OP_ATTR(pad_top, long) | |||
| OP_ATTR_WITH_VALUE(trans_format, bool, false) | |||
| OP_SCHEMA_DEF_END(SplitWithOverlap) | |||
| @@ -237,6 +237,7 @@ | |||
| #include "ops/log_softmax.h" | |||
| #include "ops/call.h" | |||
| #include "ops/cumsum.h" | |||
| #include "ops/split_with_overlap.h" | |||
| #define FUNC_MSOP2SCHEMAOP_DECLARE(OP) \ | |||
| namespace mindspore::lite::ops { \ | |||
| @@ -448,5 +449,6 @@ FUNC_MSOP2SCHEMAOP_DECLARE(Splice); | |||
| FUNC_MSOP2SCHEMAOP_DECLARE(LogSoftmax); | |||
| FUNC_MSOP2SCHEMAOP_DECLARE(Call); | |||
| FUNC_MSOP2SCHEMAOP_DECLARE(CumSum); | |||
| FUNC_MSOP2SCHEMAOP_DECLARE(SplitWithOverlap); | |||
| #endif | |||
| #endif // MINDSPORE_LITE_SRC_OPS_OPS_FUNC_DECLARE_H_ | |||
| @@ -0,0 +1,67 @@ | |||
| /** | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/populate/populate_register.h" | |||
| #include "nnacl/split_parameter.h" | |||
| using mindspore::schema::PrimitiveType_SplitWithOverlap; | |||
| namespace mindspore { | |||
| namespace lite { | |||
| OpParameter *PopulateSplitWithOverlapParameter(const void *prim) { | |||
| auto *split_with_over_lap_param = | |||
| reinterpret_cast<SplitWithOverlapParameter *>(malloc(sizeof(SplitWithOverlapParameter))); | |||
| if (split_with_over_lap_param == nullptr) { | |||
| MS_LOG(ERROR) << "malloc PopulateSplitWithOverlapParameter failed."; | |||
| return nullptr; | |||
| } | |||
| memset(split_with_over_lap_param, 0, sizeof(SplitWithOverlapParameter)); | |||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||
| auto value = primitive->value_as_SplitWithOverlap(); | |||
| split_with_over_lap_param->op_parameter_.type_ = primitive->value_type(); | |||
| auto ratio = value->ratio(); | |||
| if (ratio->size() > SPLIT_MAX_SLICE_NUM) { | |||
| MS_LOG(ERROR) << "SplitWithOverlap do not support splitting tensor into more than " << SPLIT_MAX_SLICE_NUM | |||
| << " slices"; | |||
| delete split_with_over_lap_param; | |||
| return nullptr; | |||
| } | |||
| split_with_over_lap_param->num_split_ = static_cast<int>(ratio->size()); | |||
| split_with_over_lap_param->split_dim_ = value->split_dim(); | |||
| auto extend_top = value->extend_top(); | |||
| auto extend_bottom = value->extend_bottom(); | |||
| if (extend_top->size() != ratio->size() || extend_bottom->size() != ratio->size()) { | |||
| MS_LOG(ERROR) << "The sizes of ratio, extend_top and extend_bottom are not identical"; | |||
| delete split_with_over_lap_param; | |||
| return nullptr; | |||
| } | |||
| for (size_t i = 0; i < ratio->size(); ++i) { | |||
| split_with_over_lap_param->ratio_[i] = (*ratio)[i]; | |||
| split_with_over_lap_param->extend_top_[i] = (*extend_top)[i]; | |||
| split_with_over_lap_param->extend_bottom_[i] = (*extend_bottom)[i]; | |||
| } | |||
| split_with_over_lap_param->stride_ = value->stride(); | |||
| split_with_over_lap_param->pad_top_ = value->pad_top(); | |||
| return reinterpret_cast<OpParameter *>(split_with_over_lap_param); | |||
| } | |||
| REG_POPULATE(PrimitiveType_SplitWithOverlap, PopulateSplitWithOverlapParameter, SCHEMA_CUR) | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,133 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/runtime/kernel/arm/base/split_with_over_lap_base.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| using mindspore::schema::PrimitiveType_SplitWithOverlap; | |||
| namespace mindspore::kernel { | |||
| void SplitWithOverlapBaseCPUKernel::CalculateSplitedShapes(const SplitWithOverlapParameter *param, | |||
| const std::vector<int> &shape) { | |||
| int total_block_count = 0; | |||
| for (auto i = 0; i < param->num_split_; i++) { | |||
| total_block_count += param->ratio_[i]; | |||
| } | |||
| auto split_dim_size = shape[param->split_dim_]; | |||
| std::vector<int> borders; | |||
| borders.emplace_back(0); | |||
| int visited_block = 0; | |||
| for (auto i = 0; i < param->num_split_ - 1; i++) { | |||
| visited_block += param->ratio_[i]; | |||
| auto cur_border = UP_DIV(split_dim_size * visited_block, total_block_count); | |||
| if (param->stride_ != 0) { | |||
| // make sure border align with stride | |||
| cur_border = UP_ROUND(cur_border + param->pad_top_, param->stride_); | |||
| borders.emplace_back(cur_border - param->pad_top_); | |||
| } else { | |||
| borders.emplace_back(cur_border); | |||
| } | |||
| } | |||
| borders.emplace_back(split_dim_size); | |||
| for (auto i = 0; i < param->num_split_; i++) { | |||
| start_indices_.emplace_back(borders[i]); | |||
| end_indices_.emplace_back(borders[i + 1]); | |||
| // overlap: calibrate start_indices and end_indices by adding extends | |||
| start_indices_[i] -= param->extend_top_[i]; | |||
| end_indices_[i] += param->extend_bottom_[i]; | |||
| } | |||
| } | |||
| int SplitWithOverlapBaseCPUKernel::Init() { return RET_OK; } | |||
| int SplitWithOverlapBaseCPUKernel::ReSize() { return RET_OK; } | |||
| int SplitWithOverlapBaseCPUKernel::Split(int task_id) { | |||
| DoSplitWithOverlapParallel(input_ptr_, output_ptr_.data(), task_id, split_dim_size_, element_bytes_, outer_total_dim_, | |||
| inner_stride_, start_indices_.data(), end_indices_.data()); | |||
| return RET_OK; | |||
| } | |||
| int SplitWithOverlapRun(void *cdata, int task_id) { | |||
| auto g_kernel = reinterpret_cast<SplitWithOverlapBaseCPUKernel *>(cdata); | |||
| auto ret = g_kernel->Split(task_id); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "SplitWithOverlapRun error task_id[" << task_id << "] error_code[" << ret << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int SplitWithOverlapBaseCPUKernel::Run() { | |||
| auto prepare_ret = Prepare(); | |||
| if (prepare_ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Prepare fail! ret: " << prepare_ret; | |||
| return prepare_ret; | |||
| } | |||
| auto in_tensor = in_tensors_.front(); | |||
| input_ptr_ = reinterpret_cast<char *>(in_tensor->data_c()); | |||
| auto input_shape = in_tensor->shape(); | |||
| start_indices_.clear(); | |||
| end_indices_.clear(); | |||
| output_ptr_.clear(); | |||
| for (int i = 0; i < param->num_split_; i++) { | |||
| output_ptr_.push_back(reinterpret_cast<char *>(out_tensors_.at(i)->data_c())); | |||
| } | |||
| CalculateSplitedShapes(param, input_shape); | |||
| outer_total_dim_ = 1; | |||
| inner_stride_ = 1; | |||
| split_dim_size_ = input_shape[param->split_dim_]; | |||
| element_bytes_ = in_tensor->Size(); | |||
| for (auto i = 0; i < param->split_dim_; i++) { | |||
| outer_total_dim_ *= input_shape[i]; | |||
| } | |||
| for (int i = static_cast<int>(input_shape.size()) - 1; i > param->split_dim_; i--) { | |||
| inner_stride_ *= input_shape[i]; | |||
| } | |||
| auto ret = ParallelLaunch(static_cast<const lite::InnerContext *>(this->context_)->thread_pool_, SplitWithOverlapRun, | |||
| this, param->num_split_); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "ParallelLaunch for SplitWIthOverlapRun run fail. errorcode:[" << ret << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SplitWithOverlap, LiteKernelCreator<SplitWithOverlapBaseCPUKernel>) | |||
| REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_SplitWithOverlap, LiteKernelCreator<SplitWithOverlapBaseCPUKernel>) | |||
| REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_SplitWithOverlap, LiteKernelCreator<SplitWithOverlapBaseCPUKernel>) | |||
| } // namespace mindspore::kernel | |||
| @@ -0,0 +1,57 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SPLIT_WITH_OVER_LAP_BASE_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SPLIT_WITH_OVER_LAP_BASE_H_ | |||
| #include <vector> | |||
| #include "include/errorcode.h" | |||
| #include "include/context.h" | |||
| #include "src/lite_kernel.h" | |||
| #include "nnacl/split_parameter.h" | |||
| #include "nnacl/base/split_with_over_lap_base.h" | |||
| namespace mindspore::kernel { | |||
| class SplitWithOverlapBaseCPUKernel : public LiteKernel { | |||
| public: | |||
| SplitWithOverlapBaseCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) | |||
| : LiteKernel(parameter, inputs, outputs, ctx) { | |||
| param = reinterpret_cast<SplitWithOverlapParameter *>(op_parameter_); | |||
| } | |||
| ~SplitWithOverlapBaseCPUKernel() override = default; | |||
| void CalculateSplitedShapes(const SplitWithOverlapParameter *param, const std::vector<int> &shape); | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| int Split(int task_id); | |||
| protected: | |||
| // range: [start, end) | |||
| std::vector<int> start_indices_; | |||
| std::vector<int> end_indices_; | |||
| int outer_total_dim_{0}; | |||
| int inner_stride_{0}; | |||
| int element_bytes_{0}; | |||
| int split_dim_size_{0}; | |||
| SplitWithOverlapParameter *param = nullptr; | |||
| char *input_ptr_{nullptr}; | |||
| std::vector<char *> output_ptr_; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SPLIT_WITH_OVER_LAP_BASE_H_ | |||