From: @zoloft Reviewed-by: Signed-off-by:pull/15760/MERGE
| @@ -0,0 +1,46 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "nnacl/base/split_with_over_lap_base.h" | |||||
| #include "nnacl/split_parameter.h" | |||||
| #include <string.h> | |||||
| #include "nnacl/errorcode.h" | |||||
// Copies every slice of the split dimension from in_data into its own output buffer.
//
// For slice s, rows [start_indices[s], end_indices[s]) of the split dimension are copied once per
// outer block into out_data[s], packed back to back.
//   num_split        number of slices (entries used in out_data/start_indices/end_indices)
//   split_dim_size   extent of the split dimension in the input
//   element_bytes    byte size used to scale every offset
//   outer_total_dim  number of outer blocks to iterate over
//   inner_stride     number of elements covered by one index of the split dimension
// Always returns NNACL_OK; no argument is validated here.
int DoSplitWithOverlap(char *in_data, char **out_data, int num_split, int split_dim_size, int element_bytes,
                       int outer_total_dim, int inner_stride, int *start_indices, int *end_indices) {
  // Bytes covered by a single index along the split dimension.
  const int row_bytes = inner_stride * element_bytes;
  // Distance in bytes between two consecutive outer blocks of the input.
  const int in_block_bytes = split_dim_size * row_bytes;
  int slice = 0;
  while (slice < num_split) {
    const int first_row = start_indices[slice];
    const int slice_bytes = (end_indices[slice] - first_row) * row_bytes;
    const char *src = in_data + first_row * row_bytes;
    char *dst = out_data[slice];
    for (int remaining = outer_total_dim; remaining > 0; --remaining) {
      (void)memcpy(dst, src, slice_bytes);
      src += in_block_bytes;
      dst += slice_bytes;
    }
    ++slice;
  }
  return NNACL_OK;
}
// Copies a single slice (slice_idx) of the split dimension from in_data into out_data[slice_idx].
//
// Rows [start_indices[slice_idx], end_indices[slice_idx]) are copied once per outer block and
// packed contiguously in the output. Offsets are scaled by inner_stride * element_bytes.
// Always returns NNACL_OK; no argument is validated here.
int DoSplitWithOverlapParallel(char *in_data, char **out_data, int slice_idx, int split_dim_size, int element_bytes,
                               int outer_total_dim, int inner_stride, int *start_indices, int *end_indices) {
  // Bytes covered by a single index along the split dimension.
  const int row_bytes = inner_stride * element_bytes;
  // Distance in bytes between two consecutive outer blocks of the input.
  const int in_block_bytes = split_dim_size * row_bytes;
  const int first_row = start_indices[slice_idx];
  const int slice_bytes = (end_indices[slice_idx] - first_row) * row_bytes;

  const char *src = in_data + first_row * row_bytes;
  char *dst = out_data[slice_idx];
  int outer = 0;
  while (outer < outer_total_dim) {
    (void)memcpy(dst, src, slice_bytes);
    src += in_block_bytes;
    dst += slice_bytes;
    ++outer;
  }
  return NNACL_OK;
}
| @@ -0,0 +1,36 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_NNACL_NNACL_SPLIT_WITH_OVER_LAP_BASE_H_ | |||||
| #define MINDSPORE_NNACL_NNACL_SPLIT_WITH_OVER_LAP_BASE_H_ | |||||
| #include "nnacl/op_base.h" | |||||
| #include "nnacl/split_parameter.h" | |||||
| #ifdef __cplusplus | |||||
| extern "C" { | |||||
| #endif | |||||
| int DoSplitWithOverlap(char *in_data, char **out_data, int num_split, int split_dim_size, int element_bytes, | |||||
| int outer_total_dim, int inner_stride, int *start_indices, int *end_indices); | |||||
| int DoSplitWithOverlapParallel(char *in_data, char **out_data, int slice_idx, int split_dim_size, int element_bytes, | |||||
| int outer_total_dim, int inner_stride, int *start_indices, int *end_indices); | |||||
| #ifdef __cplusplus | |||||
| } | |||||
| #endif | |||||
| #endif // MINDSPORE_NNACL_NNACL_SPLIT_WITH_OVER_LAP_BASE_H_ | |||||
| @@ -217,8 +217,9 @@ enum PrimType { | |||||
| PrimType_Call = 190, | PrimType_Call = 190, | ||||
| PrimType_Custom = 191, | PrimType_Custom = 191, | ||||
| PrimType_CumSum = 192, | PrimType_CumSum = 192, | ||||
| PrimType_SplitWithOverlap = 193, | |||||
| PrimType_MIN = PrimType_NONE, | PrimType_MIN = PrimType_NONE, | ||||
| PrimType_MAX = PrimType_CumSum + 1 | |||||
| PrimType_MAX = PrimType_SplitWithOverlap + 1 | |||||
| }; | }; | ||||
| void RegInfer(int prim_type, InferShape func); | void RegInfer(int prim_type, InferShape func); | ||||
| @@ -0,0 +1,97 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "nnacl/infer/split_with_over_lap_infer.h" | |||||
| #include "nnacl/infer/infer_register.h" | |||||
| int SplitWithOverlapInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, | |||||
| OpParameter *parameter) { | |||||
| #ifdef Debug | |||||
| int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); | |||||
| if (check_ret != NNACL_OK) { | |||||
| return check_ret; | |||||
| } | |||||
| #endif | |||||
| if (!parameter->infer_flag_) { | |||||
| return NNACL_INFER_INVALID; | |||||
| } | |||||
| const TensorC *input = inputs[0]; | |||||
| if (inputs_size < 1) { | |||||
| return NNACL_ERR; | |||||
| } | |||||
| if (outputs_size == 0) { | |||||
| return NNACL_ERR; | |||||
| } | |||||
| SplitWithOverlapParameter *param = (SplitWithOverlapParameter *)parameter; | |||||
| int number_split = param->num_split_; | |||||
| if (outputs_size != number_split) { | |||||
| return NNACL_ERR; | |||||
| } | |||||
| int stride = param->stride_; | |||||
| int pad_top = param->pad_top_; | |||||
| int split_dim = param->split_dim_; | |||||
| int ratio[SPLIT_MAX_SLICE_NUM]; | |||||
| int extend_top[SPLIT_MAX_SLICE_NUM]; | |||||
| int extend_bottom[SPLIT_MAX_SLICE_NUM]; | |||||
| for (int i = 0; i < number_split; ++i) { | |||||
| ratio[i] = param->ratio_[i]; | |||||
| extend_top[i] = param->extend_top_[i]; | |||||
| extend_bottom[i] = param->extend_bottom_[i]; | |||||
| } | |||||
| const int *input_shape = input->shape_; | |||||
| int split_dim_size = input_shape[split_dim]; | |||||
| int total_block_count = 0; | |||||
| for (int i = 0; i < number_split; i++) { | |||||
| total_block_count += ratio[i]; | |||||
| } | |||||
| int borders[MAX_SHAPE_SIZE]; | |||||
| borders[0] = 0; | |||||
| int visited_block = 0; | |||||
| for (int i = 0; i < number_split - 1; i++) { | |||||
| visited_block += ratio[i]; | |||||
| int cur_border = UP_DIV(split_dim_size * visited_block, total_block_count); | |||||
| if (stride != 0) { | |||||
| // make sure border align with stride | |||||
| cur_border = UP_ROUND(cur_border + pad_top, stride); | |||||
| borders[i + 1] = cur_border - pad_top; | |||||
| } else { | |||||
| borders[i + 1] = cur_border; | |||||
| } | |||||
| } | |||||
| borders[number_split - 1] = split_dim_size; | |||||
| for (int i = 0; i < number_split; ++i) { | |||||
| int output_shape[MAX_SHAPE_SIZE]; | |||||
| for (int dim = 0; dim < input->shape_size_; dim++) { | |||||
| if (dim == split_dim) { | |||||
| int splited_size = borders[i + 1] - borders[i]; | |||||
| splited_size += extend_top[i] + extend_bottom[i]; | |||||
| output_shape[dim] = splited_size; | |||||
| } else { | |||||
| output_shape[dim] = input_shape[dim]; | |||||
| } | |||||
| } | |||||
| SetShapeArray(outputs[i], output_shape, input->shape_size_); | |||||
| SetDataTypeFormat(outputs[i], input); | |||||
| } | |||||
| return NNACL_OK; | |||||
| } | |||||
| REG_INFER(SplitWithOverlap, PrimType_SplitWithOverlap, SplitWithOverlapInferShape) | |||||
| @@ -0,0 +1,32 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_NNACL_SPLIT_WITH_OVER_LAP_INFER_H | |||||
| #define MINDSPORE_NNACL_SPLIT_WITH_OVER_LAP_INFER_H | |||||
| #include "nnacl/infer/common_infer.h" | |||||
| #include "nnacl/split_parameter.h" | |||||
| #ifdef __cplusplus | |||||
| extern "C" { | |||||
| #endif | |||||
| int SplitWithOverlapInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, | |||||
| OpParameter *parameter); | |||||
| #ifdef __cplusplus | |||||
| } | |||||
| #endif | |||||
| #endif // MINDSPORE_NNACL_SPLIT_WITH_OVER_LAP_INFER_H | |||||
| @@ -20,6 +20,7 @@ | |||||
| #include "nnacl/op_base.h" | #include "nnacl/op_base.h" | ||||
| #define SPLIT_STRIDES_SIZE 32 | #define SPLIT_STRIDES_SIZE 32 | ||||
| #define SPLIT_MAX_SLICE_NUM 10 | |||||
| typedef struct SplitQuantArg { | typedef struct SplitQuantArg { | ||||
| QuantArg in_args_; | QuantArg in_args_; | ||||
| @@ -44,4 +45,15 @@ typedef struct SplitParameter { | |||||
| int split_count_; | int split_count_; | ||||
| } SplitParameter; | } SplitParameter; | ||||
| typedef struct SplitWithOverlapParameter { | |||||
| OpParameter op_parameter_; | |||||
| int num_split_; | |||||
| int split_dim_; | |||||
| int stride_; | |||||
| int pad_top_; | |||||
| int ratio_[SPLIT_MAX_SLICE_NUM]; | |||||
| int extend_top_[SPLIT_MAX_SLICE_NUM]; | |||||
| int extend_bottom_[SPLIT_MAX_SLICE_NUM]; | |||||
| } SplitWithOverlapParameter; | |||||
| #endif // MINDSPORE_NNACL_SPLIT_PARAMETER_H_ | #endif // MINDSPORE_NNACL_SPLIT_PARAMETER_H_ | ||||
| @@ -234,6 +234,13 @@ constexpr auto kSideEffectIO = "side_effect_io"; | |||||
| constexpr auto kDeviceType = "device_type"; | constexpr auto kDeviceType = "device_type"; | ||||
| constexpr auto kExclusive = "exclusive"; | constexpr auto kExclusive = "exclusive"; | ||||
| constexpr auto kReverse = "reverse"; | constexpr auto kReverse = "reverse"; | ||||
| constexpr auto kSplitStride = "split_stride"; | |||||
| constexpr auto kExtendTop = "extend_top"; | |||||
| constexpr auto kExtendBottom = "extend_bottom"; | |||||
| constexpr auto kNumberSplit = "number_split"; | |||||
| constexpr auto kSplitDim = "split_dim"; | |||||
| constexpr auto kPadTop = "pad_top"; | |||||
| constexpr auto kTransFormat = "trans_format"; | |||||
| const std::set<TypePtr> common_valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, | const std::set<TypePtr> common_valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, | ||||
| kUInt32, kUInt64, kFloat16, kFloat32, kFloat64}; | kUInt32, kUInt64, kFloat16, kFloat32, kFloat64}; | ||||
| @@ -0,0 +1,96 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ops/split_with_overlap.h" | |||||
| #include "ops/op_utils.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| void SplitWithOverlap::Init(int64_t number_split, const std::vector<int64_t> &ratio, | |||||
| const std::vector<int64_t> &extend_top, const std::vector<int64_t> &extend_bottom, | |||||
| int64_t split_dim, int64_t stride, int64_t pad_top, bool trans_format) { | |||||
| this->set_number_split(number_split); | |||||
| this->set_ratio(ratio); | |||||
| this->set_extend_top(extend_top); | |||||
| this->set_extend_bottom(extend_bottom); | |||||
| this->set_split_dim(split_dim); | |||||
| this->set_stride(stride); | |||||
| this->set_pad_top(pad_top); | |||||
| this->set_trans_format(trans_format); | |||||
| } | |||||
| void SplitWithOverlap::set_ratio(const std::vector<int64_t> &ratio) { this->AddAttr(kRatio, MakeValue(ratio)); } | |||||
| void SplitWithOverlap::set_extend_top(const std::vector<int64_t> &extend_top) { | |||||
| this->AddAttr(kExtendTop, MakeValue(extend_top)); | |||||
| } | |||||
| void SplitWithOverlap::set_extend_bottom(const std::vector<int64_t> &extend_bottom) { | |||||
| this->AddAttr(kExtendBottom, MakeValue(extend_bottom)); | |||||
| } | |||||
| void SplitWithOverlap::set_number_split(int64_t number_split) { this->AddAttr(kNumberSplit, MakeValue(number_split)); } | |||||
| void SplitWithOverlap::set_split_dim(int64_t split_dim) { this->AddAttr(kSplitDim, MakeValue(split_dim)); } | |||||
| void SplitWithOverlap::set_stride(int64_t stride) { this->AddAttr(kSplitStride, MakeValue(stride)); } | |||||
| void SplitWithOverlap::set_pad_top(int64_t pad_top) { this->AddAttr(kPadTop, MakeValue(pad_top)); } | |||||
| void SplitWithOverlap::set_trans_format(bool trans_format) { this->AddAttr(kTransFormat, MakeValue(trans_format)); } | |||||
| std::vector<int64_t> SplitWithOverlap::get_ratio() const { | |||||
| auto value_ptr = GetAttr(kRatio); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||||
| } | |||||
| std::vector<int64_t> SplitWithOverlap::get_extend_top() const { | |||||
| auto value_ptr = GetAttr(kExtendTop); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||||
| } | |||||
| std::vector<int64_t> SplitWithOverlap::get_extend_bottom() const { | |||||
| auto value_ptr = GetAttr(kExtendBottom); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||||
| } | |||||
| int64_t SplitWithOverlap::get_number_split() const { | |||||
| auto value_ptr = GetAttr(kNumberSplit); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| int64_t SplitWithOverlap::get_split_dim() const { | |||||
| auto value_ptr = GetAttr(kSplitDim); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| int64_t SplitWithOverlap::get_stride() const { | |||||
| auto value_ptr = GetAttr(kSplitStride); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| int64_t SplitWithOverlap::get_pad_top() const { | |||||
| auto value_ptr = GetAttr(kPadTop); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| bool SplitWithOverlap::get_trans_format() const { | |||||
| auto value_ptr = GetAttr(kTransFormat); | |||||
| return GetValue<bool>(value_ptr); | |||||
| } | |||||
| REGISTER_PRIMITIVE_C(kNameSplitWithOverlap, SplitWithOverlap); | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,58 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CORE_OPS_SPLIT_WITH_OVERLAP_H_ | |||||
| #define MINDSPORE_CORE_OPS_SPLIT_WITH_OVERLAP_H_ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| constexpr auto kNameSplitWithOverlap = "SplitWithOverlap"; | |||||
| class SplitWithOverlap : public PrimitiveC { | |||||
| public: | |||||
| SplitWithOverlap() : PrimitiveC(kNameSplitWithOverlap) {} | |||||
| ~SplitWithOverlap() = default; | |||||
| MS_DECLARE_PARENT(SplitWithOverlap, PrimitiveC); | |||||
| void Init(int64_t number_split, const std::vector<int64_t> &ratio, const std::vector<int64_t> &extend_top, | |||||
| const std::vector<int64_t> &extend_bottom, int64_t split_dim, int64_t stride, int64_t pad_top, | |||||
| bool trans_format); | |||||
| void set_ratio(const std::vector<int64_t> &ratio); | |||||
| void set_extend_top(const std::vector<int64_t> &extend_top); | |||||
| void set_extend_bottom(const std::vector<int64_t> &extend_bottom); | |||||
| void set_number_split(int64_t number_split); | |||||
| void set_split_dim(int64_t split_dim); | |||||
| void set_stride(int64_t stride); | |||||
| void set_pad_top(int64_t pad_top); | |||||
| void set_trans_format(bool trans_format); | |||||
| std::vector<int64_t> get_ratio() const; | |||||
| std::vector<int64_t> get_extend_top() const; | |||||
| std::vector<int64_t> get_extend_bottom() const; | |||||
| int64_t get_number_split() const; | |||||
| int64_t get_split_dim() const; | |||||
| int64_t get_stride() const; | |||||
| int64_t get_pad_top() const; | |||||
| bool get_trans_format() const; | |||||
| }; | |||||
| AbstractBasePtr SplitWithOverlapInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args); | |||||
| using PrimSplitWithOverlap = std::shared_ptr<SplitWithOverlap>; | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CORE_OPS_SPLIT_WITH_OVERLAP_H_ | |||||
| @@ -210,6 +210,7 @@ union PrimitiveType { | |||||
| Call, | Call, | ||||
| Custom, | Custom, | ||||
| CumSum, | CumSum, | ||||
| SplitWithOverlap, | |||||
| } | } | ||||
| table Abs { | table Abs { | ||||
| @@ -1115,3 +1116,14 @@ table Custom { | |||||
| type: string; | type: string; | ||||
| attr: [Attribute]; | attr: [Attribute]; | ||||
| } | } | ||||
| table SplitWithOverlap { | |||||
| number_split: long; | |||||
| ratio: [long]; | |||||
| extend_top: [long]; | |||||
| extend_bottom: [long]; | |||||
| split_dim: long; | |||||
| stride: long; | |||||
| pad_top: long; | |||||
| trans_format: bool = false; | |||||
| } | |||||
| @@ -209,6 +209,7 @@ OP_TYPE(LogSoftmax) | |||||
| OP_TYPE(Call) | OP_TYPE(Call) | ||||
| OP_TYPE(Custom) | OP_TYPE(Custom) | ||||
| OP_TYPE(CumSum) | OP_TYPE(CumSum) | ||||
| OP_TYPE(SplitWithOverlap) | |||||
| OP_TYPE_DEF_END(PrimitiveType) | OP_TYPE_DEF_END(PrimitiveType) | ||||
| OP_SCHEMA_DEF(Abs) | OP_SCHEMA_DEF(Abs) | ||||
| @@ -1114,3 +1115,14 @@ OP_SCHEMA_DEF_ONLY(Custom) | |||||
| OP_ATTR_ONLY(type, string) | OP_ATTR_ONLY(type, string) | ||||
| OP_ATTR_ONLY(attr, [Attribute]) | OP_ATTR_ONLY(attr, [Attribute]) | ||||
| OP_SCHEMA_DEF_ONLY_END(Custom) | OP_SCHEMA_DEF_ONLY_END(Custom) | ||||
| OP_SCHEMA_DEF(SplitWithOverlap) | |||||
| OP_ATTR(number_split, long) | |||||
| OP_ATTR(ratio, [long]) | |||||
| OP_ATTR(extend_top, [long]) | |||||
| OP_ATTR(extend_bottom, [long]) | |||||
| OP_ATTR(split_dim, long) | |||||
| OP_ATTR(stride, long) | |||||
| OP_ATTR(pad_top, long) | |||||
| OP_ATTR_WITH_VALUE(trans_format, bool, false) | |||||
| OP_SCHEMA_DEF_END(SplitWithOverlap) | |||||
| @@ -237,6 +237,7 @@ | |||||
| #include "ops/log_softmax.h" | #include "ops/log_softmax.h" | ||||
| #include "ops/call.h" | #include "ops/call.h" | ||||
| #include "ops/cumsum.h" | #include "ops/cumsum.h" | ||||
| #include "ops/split_with_overlap.h" | |||||
| #define FUNC_MSOP2SCHEMAOP_DECLARE(OP) \ | #define FUNC_MSOP2SCHEMAOP_DECLARE(OP) \ | ||||
| namespace mindspore::lite::ops { \ | namespace mindspore::lite::ops { \ | ||||
| @@ -448,5 +449,6 @@ FUNC_MSOP2SCHEMAOP_DECLARE(Splice); | |||||
| FUNC_MSOP2SCHEMAOP_DECLARE(LogSoftmax); | FUNC_MSOP2SCHEMAOP_DECLARE(LogSoftmax); | ||||
| FUNC_MSOP2SCHEMAOP_DECLARE(Call); | FUNC_MSOP2SCHEMAOP_DECLARE(Call); | ||||
| FUNC_MSOP2SCHEMAOP_DECLARE(CumSum); | FUNC_MSOP2SCHEMAOP_DECLARE(CumSum); | ||||
| FUNC_MSOP2SCHEMAOP_DECLARE(SplitWithOverlap); | |||||
| #endif | #endif | ||||
| #endif // MINDSPORE_LITE_SRC_OPS_OPS_FUNC_DECLARE_H_ | #endif // MINDSPORE_LITE_SRC_OPS_OPS_FUNC_DECLARE_H_ | ||||
| @@ -0,0 +1,67 @@ | |||||
| /** | |||||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/ops/populate/populate_register.h" | |||||
| #include "nnacl/split_parameter.h" | |||||
| using mindspore::schema::PrimitiveType_SplitWithOverlap; | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| OpParameter *PopulateSplitWithOverlapParameter(const void *prim) { | |||||
| auto *split_with_over_lap_param = | |||||
| reinterpret_cast<SplitWithOverlapParameter *>(malloc(sizeof(SplitWithOverlapParameter))); | |||||
| if (split_with_over_lap_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc PopulateSplitWithOverlapParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(split_with_over_lap_param, 0, sizeof(SplitWithOverlapParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| auto value = primitive->value_as_SplitWithOverlap(); | |||||
| split_with_over_lap_param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto ratio = value->ratio(); | |||||
| if (ratio->size() > SPLIT_MAX_SLICE_NUM) { | |||||
| MS_LOG(ERROR) << "SplitWithOverlap do not support splitting tensor into more than " << SPLIT_MAX_SLICE_NUM | |||||
| << " slices"; | |||||
| delete split_with_over_lap_param; | |||||
| return nullptr; | |||||
| } | |||||
| split_with_over_lap_param->num_split_ = static_cast<int>(ratio->size()); | |||||
| split_with_over_lap_param->split_dim_ = value->split_dim(); | |||||
| auto extend_top = value->extend_top(); | |||||
| auto extend_bottom = value->extend_bottom(); | |||||
| if (extend_top->size() != ratio->size() || extend_bottom->size() != ratio->size()) { | |||||
| MS_LOG(ERROR) << "The sizes of ratio, extend_top and extend_bottom are not identical"; | |||||
| delete split_with_over_lap_param; | |||||
| return nullptr; | |||||
| } | |||||
| for (size_t i = 0; i < ratio->size(); ++i) { | |||||
| split_with_over_lap_param->ratio_[i] = (*ratio)[i]; | |||||
| split_with_over_lap_param->extend_top_[i] = (*extend_top)[i]; | |||||
| split_with_over_lap_param->extend_bottom_[i] = (*extend_bottom)[i]; | |||||
| } | |||||
| split_with_over_lap_param->stride_ = value->stride(); | |||||
| split_with_over_lap_param->pad_top_ = value->pad_top(); | |||||
| return reinterpret_cast<OpParameter *>(split_with_over_lap_param); | |||||
| } | |||||
| REG_POPULATE(PrimitiveType_SplitWithOverlap, PopulateSplitWithOverlapParameter, SCHEMA_CUR) | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,133 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/arm/base/split_with_over_lap_base.h" | |||||
| #include "schema/model_generated.h" | |||||
| #include "src/kernel_registry.h" | |||||
| #include "src/runtime/runtime_api.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::lite::RET_ERROR; | |||||
| using mindspore::lite::RET_OK; | |||||
| using mindspore::schema::PrimitiveType_SplitWithOverlap; | |||||
| namespace mindspore::kernel { | |||||
| void SplitWithOverlapBaseCPUKernel::CalculateSplitedShapes(const SplitWithOverlapParameter *param, | |||||
| const std::vector<int> &shape) { | |||||
| int total_block_count = 0; | |||||
| for (auto i = 0; i < param->num_split_; i++) { | |||||
| total_block_count += param->ratio_[i]; | |||||
| } | |||||
| auto split_dim_size = shape[param->split_dim_]; | |||||
| std::vector<int> borders; | |||||
| borders.emplace_back(0); | |||||
| int visited_block = 0; | |||||
| for (auto i = 0; i < param->num_split_ - 1; i++) { | |||||
| visited_block += param->ratio_[i]; | |||||
| auto cur_border = UP_DIV(split_dim_size * visited_block, total_block_count); | |||||
| if (param->stride_ != 0) { | |||||
| // make sure border align with stride | |||||
| cur_border = UP_ROUND(cur_border + param->pad_top_, param->stride_); | |||||
| borders.emplace_back(cur_border - param->pad_top_); | |||||
| } else { | |||||
| borders.emplace_back(cur_border); | |||||
| } | |||||
| } | |||||
| borders.emplace_back(split_dim_size); | |||||
| for (auto i = 0; i < param->num_split_; i++) { | |||||
| start_indices_.emplace_back(borders[i]); | |||||
| end_indices_.emplace_back(borders[i + 1]); | |||||
| // overlap: calibrate start_indices and end_indices by adding extends | |||||
| start_indices_[i] -= param->extend_top_[i]; | |||||
| end_indices_[i] += param->extend_bottom_[i]; | |||||
| } | |||||
| } | |||||
| int SplitWithOverlapBaseCPUKernel::Init() { return RET_OK; } | |||||
| int SplitWithOverlapBaseCPUKernel::ReSize() { return RET_OK; } | |||||
| int SplitWithOverlapBaseCPUKernel::Split(int task_id) { | |||||
| DoSplitWithOverlapParallel(input_ptr_, output_ptr_.data(), task_id, split_dim_size_, element_bytes_, outer_total_dim_, | |||||
| inner_stride_, start_indices_.data(), end_indices_.data()); | |||||
| return RET_OK; | |||||
| } | |||||
| int SplitWithOverlapRun(void *cdata, int task_id) { | |||||
| auto g_kernel = reinterpret_cast<SplitWithOverlapBaseCPUKernel *>(cdata); | |||||
| auto ret = g_kernel->Split(task_id); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "SplitWithOverlapRun error task_id[" << task_id << "] error_code[" << ret << "]"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int SplitWithOverlapBaseCPUKernel::Run() { | |||||
| auto prepare_ret = Prepare(); | |||||
| if (prepare_ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Prepare fail! ret: " << prepare_ret; | |||||
| return prepare_ret; | |||||
| } | |||||
| auto in_tensor = in_tensors_.front(); | |||||
| input_ptr_ = reinterpret_cast<char *>(in_tensor->data_c()); | |||||
| auto input_shape = in_tensor->shape(); | |||||
| start_indices_.clear(); | |||||
| end_indices_.clear(); | |||||
| output_ptr_.clear(); | |||||
| for (int i = 0; i < param->num_split_; i++) { | |||||
| output_ptr_.push_back(reinterpret_cast<char *>(out_tensors_.at(i)->data_c())); | |||||
| } | |||||
| CalculateSplitedShapes(param, input_shape); | |||||
| outer_total_dim_ = 1; | |||||
| inner_stride_ = 1; | |||||
| split_dim_size_ = input_shape[param->split_dim_]; | |||||
| element_bytes_ = in_tensor->Size(); | |||||
| for (auto i = 0; i < param->split_dim_; i++) { | |||||
| outer_total_dim_ *= input_shape[i]; | |||||
| } | |||||
| for (int i = static_cast<int>(input_shape.size()) - 1; i > param->split_dim_; i--) { | |||||
| inner_stride_ *= input_shape[i]; | |||||
| } | |||||
| auto ret = ParallelLaunch(static_cast<const lite::InnerContext *>(this->context_)->thread_pool_, SplitWithOverlapRun, | |||||
| this, param->num_split_); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "ParallelLaunch for SplitWIthOverlapRun run fail. errorcode:[" << ret << "]"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SplitWithOverlap, LiteKernelCreator<SplitWithOverlapBaseCPUKernel>) | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_SplitWithOverlap, LiteKernelCreator<SplitWithOverlapBaseCPUKernel>) | |||||
| REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_SplitWithOverlap, LiteKernelCreator<SplitWithOverlapBaseCPUKernel>) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,57 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SPLIT_WITH_OVER_LAP_BASE_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SPLIT_WITH_OVER_LAP_BASE_H_ | |||||
| #include <vector> | |||||
| #include "include/errorcode.h" | |||||
| #include "include/context.h" | |||||
| #include "src/lite_kernel.h" | |||||
| #include "nnacl/split_parameter.h" | |||||
| #include "nnacl/base/split_with_over_lap_base.h" | |||||
| namespace mindspore::kernel { | |||||
| class SplitWithOverlapBaseCPUKernel : public LiteKernel { | |||||
| public: | |||||
| SplitWithOverlapBaseCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) | |||||
| : LiteKernel(parameter, inputs, outputs, ctx) { | |||||
| param = reinterpret_cast<SplitWithOverlapParameter *>(op_parameter_); | |||||
| } | |||||
| ~SplitWithOverlapBaseCPUKernel() override = default; | |||||
| void CalculateSplitedShapes(const SplitWithOverlapParameter *param, const std::vector<int> &shape); | |||||
| int Init() override; | |||||
| int ReSize() override; | |||||
| int Run() override; | |||||
| int Split(int task_id); | |||||
| protected: | |||||
| // range: [start, end) | |||||
| std::vector<int> start_indices_; | |||||
| std::vector<int> end_indices_; | |||||
| int outer_total_dim_{0}; | |||||
| int inner_stride_{0}; | |||||
| int element_bytes_{0}; | |||||
| int split_dim_size_{0}; | |||||
| SplitWithOverlapParameter *param = nullptr; | |||||
| char *input_ptr_{nullptr}; | |||||
| std::vector<char *> output_ptr_; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SPLIT_WITH_OVER_LAP_BASE_H_ | |||||