From: @zoloft Reviewed-by: @wangchengyuan Signed-off-by: @wangchengyuantags/v1.2.0-rc1
| @@ -226,6 +226,9 @@ constexpr auto kCoeff = "coeff"; | |||||
| constexpr auto kIsDepthWise = "is_depth_wise"; | constexpr auto kIsDepthWise = "is_depth_wise"; | ||||
| constexpr auto kZoneoutCell = "zoneout_cell"; | constexpr auto kZoneoutCell = "zoneout_cell"; | ||||
| constexpr auto kZoneoutHidden = "zoneout_hidden"; | constexpr auto kZoneoutHidden = "zoneout_hidden"; | ||||
| constexpr auto kSpliceContext = "context"; | |||||
| constexpr auto kSpliceForwardIndexes = "forward_indexes"; | |||||
| constexpr auto kSpliceOutputDims = "output_dim"; | |||||
| const std::set<TypeId> common_valid_types = { | const std::set<TypeId> common_valid_types = { | ||||
| kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt64, kNumberTypeUInt8, kNumberTypeUInt16, | kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt64, kNumberTypeUInt8, kNumberTypeUInt16, | ||||
| @@ -0,0 +1,54 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ops/splice.h" | |||||
| #include <vector> | |||||
| #include "ops/op_utils.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| void Splice::Init(const std::vector<int64_t> &contexts, const std::vector<int64_t> &forward_indexes, | |||||
| int64_t output_dims) { | |||||
| this->set_context(contexts); | |||||
| this->set_forward_indexes(forward_indexes); | |||||
| this->set_output_dim(output_dims); | |||||
| } | |||||
| void Splice::set_context(const std::vector<int64_t> &contexts) { this->AddAttr(kSpliceContext, MakeValue(contexts)); } | |||||
| void Splice::set_forward_indexes(const std::vector<int64_t> &forward_indexes) { | |||||
| this->AddAttr(kSpliceForwardIndexes, MakeValue(forward_indexes)); | |||||
| } | |||||
| void Splice::set_output_dim(int64_t output_dim) { this->AddAttr(kSpliceOutputDims, MakeValue(output_dim)); } | |||||
| std::vector<int64_t> Splice::get_context() const { | |||||
| auto value_ptr = GetAttr(kSpliceContext); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||||
| } | |||||
| std::vector<int64_t> Splice::get_forward_indexes() const { | |||||
| auto value_ptr = GetAttr(kSpliceForwardIndexes); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||||
| } | |||||
| int64_t Splice::get_output_dim() const { | |||||
| auto value_ptr = GetAttr(kSpliceOutputDims); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| REGISTER_PRIMITIVE_C(kNameSplice, Splice); | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,47 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CORE_OPS_SPLICE_H_ | |||||
| #define MINDSPORE_CORE_OPS_SPLICE_H_ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| constexpr auto kNameSplice = "Splice"; | |||||
| class Splice : public PrimitiveC { | |||||
| public: | |||||
| Splice() : PrimitiveC(kNameSplice) { InitIOName({"inputs"}, {"outputs"}); } | |||||
| ~Splice() = default; | |||||
| MS_DECLARE_PARENT(Splice, PrimitiveC); | |||||
| void Init(const std::vector<int64_t> &contexts, const std::vector<int64_t> &forward_indexes, int64_t output_dims); | |||||
| void set_context(const std::vector<int64_t> &contexts); | |||||
| void set_forward_indexes(const std::vector<int64_t> &forward_indexes); | |||||
| void set_output_dim(int64_t output_dim); | |||||
| std::vector<int64_t> get_context() const; | |||||
| std::vector<int64_t> get_forward_indexes() const; | |||||
| int64_t get_output_dim() const; | |||||
| AbstractBasePtr SpliceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args); | |||||
| }; | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CORE_OPS_SPLICE_H_ | |||||
| @@ -91,6 +91,7 @@ set(CODER_OPCODERS_SRC | |||||
| ${MICRO_DIR}/coder/opcoders/nnacl/fp32/softmax_fp32_coder.cc | ${MICRO_DIR}/coder/opcoders/nnacl/fp32/softmax_fp32_coder.cc | ||||
| ${MICRO_DIR}/coder/opcoders/nnacl/fp32/tile_fp32_coder.cc | ${MICRO_DIR}/coder/opcoders/nnacl/fp32/tile_fp32_coder.cc | ||||
| ${MICRO_DIR}/coder/opcoders/nnacl/fp32/transpose_fp32_coder.cc | ${MICRO_DIR}/coder/opcoders/nnacl/fp32/transpose_fp32_coder.cc | ||||
| ${MICRO_DIR}/coder/opcoders/nnacl/fp32/splice_fp32_coder.cc | |||||
| #### nnacl int8 coder | #### nnacl int8 coder | ||||
| ${MICRO_DIR}/coder/opcoders/nnacl/int8/activation_int8_coder.cc | ${MICRO_DIR}/coder/opcoders/nnacl/int8/activation_int8_coder.cc | ||||
| ${MICRO_DIR}/coder/opcoders/nnacl/int8/add_int8_coder.cc | ${MICRO_DIR}/coder/opcoders/nnacl/int8/add_int8_coder.cc | ||||
| @@ -161,6 +162,7 @@ set(LITE_SRC | |||||
| ${LITE_DIR}/src/ops/populate/bias_add_populate.cc | ${LITE_DIR}/src/ops/populate/bias_add_populate.cc | ||||
| ${LITE_DIR}/src/ops/populate/activation_populate.cc | ${LITE_DIR}/src/ops/populate/activation_populate.cc | ||||
| ${LITE_DIR}/src/ops/populate/softmax_populate.cc | ${LITE_DIR}/src/ops/populate/softmax_populate.cc | ||||
| ${LITE_DIR}/src/ops/populate/splice_populate.cc | |||||
| ### tools | ### tools | ||||
| ${LITE_DIR}/tools/common/flag_parser.cc | ${LITE_DIR}/tools/common/flag_parser.cc | ||||
| ) | ) | ||||
| @@ -32,14 +32,10 @@ int SpliceFP32Coder::DoCode(CoderContext *const context) { | |||||
| MS_LOG(ERROR) << "SpliceFP32Coder src_shape size not equal to dst_shape"; | MS_LOG(ERROR) << "SpliceFP32Coder src_shape size not equal to dst_shape"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| int src_row = src_shape.at(kInputIndex); | |||||
| int dst_row = dst_shape.at(kInputIndex); | |||||
| int src_row = src_shape.at(kWeightIndex); | |||||
| int dst_row = dst_shape.at(kWeightIndex); | |||||
| int src_col = src_shape.at(kBiasIndex); | int src_col = src_shape.at(kBiasIndex); | ||||
| int dst_col = dst_shape.at(kBiasIndex); | int dst_col = dst_shape.at(kBiasIndex); | ||||
| if (src_row != dst_row) { | |||||
| MS_LOG(ERROR) << "SpliceFP32Coder src_row not equal to dst_row"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (src_col * splice_parameter->context_dim_ != dst_col) { | if (src_col * splice_parameter->context_dim_ != dst_col) { | ||||
| MS_LOG(ERROR) << "SpliceFP32Coder src_col not match to dst_col"; | MS_LOG(ERROR) << "SpliceFP32Coder src_col not match to dst_col"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -116,6 +116,7 @@ void NNaclFp32Serializer::CodeStruct(const std::string &name, const DeQuantArg & | |||||
| void NNaclFp32Serializer::CodeStruct(const std::string &name, const SpliceParameter &splice_parameter) { | void NNaclFp32Serializer::CodeStruct(const std::string &name, const SpliceParameter &splice_parameter) { | ||||
| CodeArray("splice_context", splice_parameter.context_, splice_parameter.context_dim_, false); | CodeArray("splice_context", splice_parameter.context_, splice_parameter.context_dim_, false); | ||||
| CodeBaseStruct("SpliceParameter", name, splice_parameter.op_parameter_, splice_parameter.context_dim_, | CodeBaseStruct("SpliceParameter", name, splice_parameter.op_parameter_, splice_parameter.context_dim_, | ||||
| splice_parameter.forward_indexes_dim_, "splice_context", nullptr, splice_parameter.output_dim_); | |||||
| splice_parameter.forward_indexes_dim_, splice_parameter.src_to_dst_row_offset_, "splice_context", | |||||
| nullptr, splice_parameter.output_dim_); | |||||
| } | } | ||||
| } // namespace mindspore::lite::micro::nnacl | } // namespace mindspore::lite::micro::nnacl | ||||
| @@ -17,13 +17,14 @@ | |||||
| #include "nnacl/fp32/splice_fp32.h" | #include "nnacl/fp32/splice_fp32.h" | ||||
| void SpliceFp32(const float *src_data, int src_row, int src_col, const SpliceParameter *splice_parameter, | void SpliceFp32(const float *src_data, int src_row, int src_col, const SpliceParameter *splice_parameter, | ||||
| float *dst_data, int dst_row, int dst_col) { | float *dst_data, int dst_row, int dst_col) { | ||||
| int row_offset = splice_parameter->src_to_dst_row_offset_; | |||||
| for (int r = 0; r < dst_row; ++r) { | for (int r = 0; r < dst_row; ++r) { | ||||
| for (int off = 0; off < splice_parameter->context_dim_; ++off) { | for (int off = 0; off < splice_parameter->context_dim_; ++off) { | ||||
| int r_off = r + splice_parameter->context_[off]; | |||||
| int r_off = r + row_offset + splice_parameter->context_[off]; | |||||
| r_off = MSMAX(r_off, 0); | r_off = MSMAX(r_off, 0); | ||||
| r_off = MSMIN(r_off, src_row - 1); | r_off = MSMIN(r_off, src_row - 1); | ||||
| const float *tmp_src_data = src_data + r_off * src_col * sizeof(float); | |||||
| float *tmp_dst_data = dst_data + r * dst_col * sizeof(float); | |||||
| const float *tmp_src_data = src_data + r_off * src_col; | |||||
| float *tmp_dst_data = dst_data + r * dst_col; | |||||
| memcpy(tmp_dst_data + off * src_col, tmp_src_data, src_col * sizeof(float)); | memcpy(tmp_dst_data + off * src_col, tmp_src_data, src_col * sizeof(float)); | ||||
| } | } | ||||
| } | } | ||||
| @@ -212,8 +212,9 @@ enum PrimType { | |||||
| PrimType_SqrtGrad = 185, | PrimType_SqrtGrad = 185, | ||||
| PrimType_LayerNormGrad = 186, | PrimType_LayerNormGrad = 186, | ||||
| PrimType_ResizeGrad = 187, | PrimType_ResizeGrad = 187, | ||||
| PrimType_Splice = 188, | |||||
| PrimType_MIN = PrimType_NONE, | PrimType_MIN = PrimType_NONE, | ||||
| PrimType_MAX = PrimType_ResizeGrad | |||||
| PrimType_MAX = PrimType_Splice | |||||
| }; | }; | ||||
| void RegInfer(int prim_type, InferShape func); | void RegInfer(int prim_type, InferShape func); | ||||
| @@ -54,3 +54,5 @@ int SpliceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC * | |||||
| output->shape_[2] = out_dim; | output->shape_[2] = out_dim; | ||||
| return NNACL_OK; | return NNACL_OK; | ||||
| } | } | ||||
| REG_INFER(Splice, PrimType_Splice, SpliceInferShape) | |||||
| @@ -21,6 +21,7 @@ typedef struct SpliceParameter { | |||||
| OpParameter op_parameter_; | OpParameter op_parameter_; | ||||
| int context_dim_; | int context_dim_; | ||||
| int forward_indexes_dim_; | int forward_indexes_dim_; | ||||
| int src_to_dst_row_offset_; | |||||
| int *context_; | int *context_; | ||||
| int *forward_indexes_; | int *forward_indexes_; | ||||
| int output_dim_; | int output_dim_; | ||||
| @@ -205,6 +205,7 @@ union PrimitiveType { | |||||
| SqrtGrad, | SqrtGrad, | ||||
| LayerNormGrad, | LayerNormGrad, | ||||
| ResizeGrad, | ResizeGrad, | ||||
| Splice, | |||||
| } | } | ||||
| table Abs { | table Abs { | ||||
| @@ -1087,3 +1088,10 @@ table ResizeGrad { | |||||
| method: ResizeMethod; | method: ResizeMethod; | ||||
| align_corners: bool; | align_corners: bool; | ||||
| } | } | ||||
| table Splice { | |||||
| context: [long]; | |||||
| forward_indexes: [long]; | |||||
| output_dim: long; | |||||
| } | |||||
| @@ -204,6 +204,7 @@ OP_TYPE(RsqrtGrad) | |||||
| OP_TYPE(SqrtGrad) | OP_TYPE(SqrtGrad) | ||||
| OP_TYPE(LayerNormGrad) | OP_TYPE(LayerNormGrad) | ||||
| OP_TYPE(ResizeGrad) | OP_TYPE(ResizeGrad) | ||||
| OP_TYPE(Splice) | |||||
| OP_TYPE_DEF_END(PrimitiveType) | OP_TYPE_DEF_END(PrimitiveType) | ||||
| OP_SCHEMA_DEF(Abs) | OP_SCHEMA_DEF(Abs) | ||||
| @@ -1086,3 +1087,9 @@ OP_SCHEMA_DEF(ResizeGrad) | |||||
| OP_ATTR_ENUM(method, ResizeMethod) | OP_ATTR_ENUM(method, ResizeMethod) | ||||
| OP_ATTR(align_corners, bool) | OP_ATTR(align_corners, bool) | ||||
| OP_SCHEMA_DEF_END(ResizeGrad) | OP_SCHEMA_DEF_END(ResizeGrad) | ||||
| OP_SCHEMA_DEF(Splice) | |||||
| OP_ATTR(context, [long]) | |||||
| OP_ATTR(forward_indexes, [long]) | |||||
| OP_ATTR(output_dim, long) | |||||
| OP_SCHEMA_DEF_END(Splice) | |||||
| @@ -242,6 +242,7 @@ | |||||
| #include "ops/lin_space.h" | #include "ops/lin_space.h" | ||||
| #include "ops/uniform_real.h" | #include "ops/uniform_real.h" | ||||
| #include "ops/grad/abs_grad.h" | #include "ops/grad/abs_grad.h" | ||||
| #include "ops/splice.h" | |||||
| #define FUNC_MSOP2SCHEMAOP_DECLARE(OP) \ | #define FUNC_MSOP2SCHEMAOP_DECLARE(OP) \ | ||||
| namespace mindspore::lite::ops { \ | namespace mindspore::lite::ops { \ | ||||
| @@ -453,5 +454,6 @@ FUNC_MSOP2SCHEMAOP_DECLARE(RsqrtGrad); | |||||
| FUNC_MSOP2SCHEMAOP_DECLARE(SqrtGrad); | FUNC_MSOP2SCHEMAOP_DECLARE(SqrtGrad); | ||||
| FUNC_MSOP2SCHEMAOP_DECLARE(LayerNormGrad); | FUNC_MSOP2SCHEMAOP_DECLARE(LayerNormGrad); | ||||
| FUNC_MSOP2SCHEMAOP_DECLARE(ResizeGrad); | FUNC_MSOP2SCHEMAOP_DECLARE(ResizeGrad); | ||||
| FUNC_MSOP2SCHEMAOP_DECLARE(Splice); | |||||
| #endif | #endif | ||||
| #endif // MINDSPORE_LITE_SRC_OPS_OPS_FUNC_DECLARE_H_ | #endif // MINDSPORE_LITE_SRC_OPS_OPS_FUNC_DECLARE_H_ | ||||
| @@ -745,6 +745,11 @@ schema::PrimitiveT *ErfPrimitiveCreator(const AnfNodePtr &node) { | |||||
| return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | ||||
| } | } | ||||
| schema::PrimitiveT *SplicePrimitiveCreator(const AnfNodePtr &node) { | |||||
| auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Splice>>(node); | |||||
| return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | |||||
| } | |||||
| RegistryMSOps g_absPrimitiveCreatorRegistry("Abs", AbsPrimitiveCreator); | RegistryMSOps g_absPrimitiveCreatorRegistry("Abs", AbsPrimitiveCreator); | ||||
| RegistryMSOps g_absGradPrimitiveCreatorRegistry("AbsGrad", AbsGradPrimitiveCreator); | RegistryMSOps g_absGradPrimitiveCreatorRegistry("AbsGrad", AbsGradPrimitiveCreator); | ||||
| RegistryMSOps g_activationPrimitiveCreatorRegistry("Activation", ActivationPrimitiveCreator); | RegistryMSOps g_activationPrimitiveCreatorRegistry("Activation", ActivationPrimitiveCreator); | ||||
| @@ -954,6 +959,7 @@ RegistryMSOps g_unsqueezePrimitiveCreatorRegistry("Unsqueeze", UnsqueezePrimitiv | |||||
| RegistryMSOps g_wherePrimitiveCreatorRegistry("Where", WherePrimitiveCreator); | RegistryMSOps g_wherePrimitiveCreatorRegistry("Where", WherePrimitiveCreator); | ||||
| RegistryMSOps g_zerosLikePrimitiveCreatorRegistry("ZerosLike", ZerosLikePrimitiveCreator); | RegistryMSOps g_zerosLikePrimitiveCreatorRegistry("ZerosLike", ZerosLikePrimitiveCreator); | ||||
| RegistryMSOps g_erfPrimitiveCreatorRegistry("Erf", ErfPrimitiveCreator); | RegistryMSOps g_erfPrimitiveCreatorRegistry("Erf", ErfPrimitiveCreator); | ||||
| RegistryMSOps g_SplicePrimitiveCreatorRegistry("Splice", SplicePrimitiveCreator); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -0,0 +1,72 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/ops/populate/populate_register.h" | |||||
| #include "nnacl/op_base.h" | |||||
| #include "nnacl/splice_parameter.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| OpParameter *PopulateSpliceParameter(const void *prim) { | |||||
| auto *splice_parameter = reinterpret_cast<SpliceParameter *>(malloc(sizeof(SpliceParameter))); | |||||
| if (splice_parameter == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc Splice Parameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(splice_parameter, 0, sizeof(SpliceParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| auto splice_primitive = primitive->value_as_Splice(); | |||||
| splice_parameter->op_parameter_.type_ = primitive->value_type(); | |||||
| std::vector<int> primitive_context(splice_primitive->context()->begin(), splice_primitive->context()->end()); | |||||
| splice_parameter->context_dim_ = static_cast<int>(primitive_context.size()); | |||||
| // malloc && memset for context | |||||
| splice_parameter->context_ = reinterpret_cast<int *>(malloc(splice_parameter->context_dim_ * sizeof(int))); | |||||
| if (splice_parameter->context_ == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc splice_parameter context_ error"; | |||||
| free(splice_parameter); | |||||
| return nullptr; | |||||
| } | |||||
| // src_to_dst_row_offset | |||||
| int src_to_dst_row_offset = INT32_MIN; | |||||
| memset(splice_parameter->context_, 0, splice_parameter->context_dim_ * sizeof(int)); | |||||
| for (int i = 0; i < splice_parameter->context_dim_; ++i) { | |||||
| splice_parameter->context_[i] = primitive_context.at(i); | |||||
| src_to_dst_row_offset = std::max(src_to_dst_row_offset, std::abs(primitive_context.at(i))); | |||||
| } | |||||
| std::vector<int> primitive_forward_indexes(splice_primitive->forward_indexes()->begin(), | |||||
| splice_primitive->forward_indexes()->end()); | |||||
| splice_parameter->forward_indexes_dim_ = static_cast<int>(primitive_forward_indexes.size()); | |||||
| // malloc && memset for forward_indexes | |||||
| splice_parameter->forward_indexes_ = | |||||
| reinterpret_cast<int *>(malloc(splice_parameter->forward_indexes_dim_ * sizeof(int))); | |||||
| if (splice_parameter->forward_indexes_ == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc splice_parameter forward_indexes_ error"; | |||||
| free(splice_parameter->context_); | |||||
| free(splice_parameter); | |||||
| return nullptr; | |||||
| } | |||||
| memset(splice_parameter->forward_indexes_, 0, splice_parameter->forward_indexes_dim_ * sizeof(int)); | |||||
| for (int i = 0; i < splice_parameter->context_dim_; ++i) { | |||||
| splice_parameter->context_[i] = primitive_context.at(i); | |||||
| } | |||||
| splice_parameter->output_dim_ = splice_primitive->output_dim(); | |||||
| return reinterpret_cast<OpParameter *>(splice_parameter); | |||||
| } | |||||
| Registry g_SpliceParameterRegistry(schema::PrimitiveType_Splice, PopulateSpliceParameter, SCHEMA_CUR); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,58 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/onnx/onnx_splice_parser.h" | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include "ops/splice.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| ops::PrimitiveC *OnnxSpliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx Splice Parser"; | |||||
| auto primitive = std::make_unique<ops::Splice>(); | |||||
| std::vector<int64_t> context; | |||||
| std::vector<int64_t> forward_indexes; | |||||
| int64_t output_dim = 0; | |||||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||||
| const std::string attribute_name = onnx_node_attr.name(); | |||||
| if (attribute_name == "context") { | |||||
| const int32_t size = onnx_node_attr.ints_size(); | |||||
| context.resize(size); | |||||
| for (int32_t i = 0; i < size; i++) { | |||||
| context[i] = static_cast<int>(onnx_node_attr.ints(i)); | |||||
| } | |||||
| } else if (attribute_name == "forward_indexes") { | |||||
| const int32_t size = onnx_node_attr.ints_size(); | |||||
| forward_indexes.resize(size); | |||||
| for (int32_t i = 0; i < size; i++) { | |||||
| forward_indexes[i] = static_cast<int>(onnx_node_attr.ints(i)); | |||||
| } | |||||
| } else if (attribute_name == "output_dim") { | |||||
| output_dim = static_cast<int>(onnx_node_attr.i()); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "unsupported attribute in splice " << attribute_name; | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
| primitive->Init(context, forward_indexes, output_dim); | |||||
| return primitive.release(); | |||||
| } | |||||
| OnnxNodeRegistrar g_onnxSpliceParser("Splice", new OnnxSpliceParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,34 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SPLICE_PARSER_H | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SPLICE_PARSER_H | |||||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class OnnxSpliceParser : public OnnxNodeParser { | |||||
| public: | |||||
| OnnxSpliceParser() : OnnxNodeParser("Splice") {} | |||||
| ~OnnxSpliceParser() override = default; | |||||
| ops::PrimitiveC *Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SPLICE_PARSER_H | |||||