diff --git a/mindspore/core/ops/op_utils.h b/mindspore/core/ops/op_utils.h index 7e3eba91ee..51436c0042 100644 --- a/mindspore/core/ops/op_utils.h +++ b/mindspore/core/ops/op_utils.h @@ -226,6 +226,9 @@ constexpr auto kCoeff = "coeff"; constexpr auto kIsDepthWise = "is_depth_wise"; constexpr auto kZoneoutCell = "zoneout_cell"; constexpr auto kZoneoutHidden = "zoneout_hidden"; +constexpr auto kSpliceContext = "context"; +constexpr auto kSpliceForwardIndexes = "forward_indexes"; +constexpr auto kSpliceOutputDims = "output_dim"; const std::set common_valid_types = { kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt64, kNumberTypeUInt8, kNumberTypeUInt16, diff --git a/mindspore/core/ops/splice.cc b/mindspore/core/ops/splice.cc new file mode 100644 index 0000000000..d5b8c302f1 --- /dev/null +++ b/mindspore/core/ops/splice.cc @@ -0,0 +1,54 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ops/splice.h" +#include +#include "ops/op_utils.h" +namespace mindspore { +namespace ops { +void Splice::Init(const std::vector &contexts, const std::vector &forward_indexes, + int64_t output_dims) { + this->set_context(contexts); + this->set_forward_indexes(forward_indexes); + this->set_output_dim(output_dims); +} + +void Splice::set_context(const std::vector &contexts) { this->AddAttr(kSpliceContext, MakeValue(contexts)); } + +void Splice::set_forward_indexes(const std::vector &forward_indexes) { + this->AddAttr(kSpliceForwardIndexes, MakeValue(forward_indexes)); +} + +void Splice::set_output_dim(int64_t output_dim) { this->AddAttr(kSpliceOutputDims, MakeValue(output_dim)); } + +std::vector Splice::get_context() const { + auto value_ptr = GetAttr(kSpliceContext); + return GetValue>(value_ptr); +} + +std::vector Splice::get_forward_indexes() const { + auto value_ptr = GetAttr(kSpliceForwardIndexes); + return GetValue>(value_ptr); +} + +int64_t Splice::get_output_dim() const { + auto value_ptr = GetAttr(kSpliceOutputDims); + return GetValue(value_ptr); +} + +REGISTER_PRIMITIVE_C(kNameSplice, Splice); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/splice.h b/mindspore/core/ops/splice.h new file mode 100644 index 0000000000..b9f1f69305 --- /dev/null +++ b/mindspore/core/ops/splice.h @@ -0,0 +1,47 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_SPLICE_H_ +#define MINDSPORE_CORE_OPS_SPLICE_H_ +#include +#include +#include "ops/primitive_c.h" +#include "abstract/abstract_value.h" +#include "utils/check_convert_utils.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameSplice = "Splice"; +class Splice : public PrimitiveC { + public: + Splice() : PrimitiveC(kNameSplice) { InitIOName({"inputs"}, {"outputs"}); } + ~Splice() = default; + MS_DECLARE_PARENT(Splice, PrimitiveC); + void Init(const std::vector &contexts, const std::vector &forward_indexes, int64_t output_dims); + void set_context(const std::vector &contexts); + void set_forward_indexes(const std::vector &forward_indexes); + void set_output_dim(int64_t output_dim); + + std::vector get_context() const; + std::vector get_forward_indexes() const; + int64_t get_output_dim() const; + AbstractBasePtr SpliceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +}; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_SPLICE_H_ diff --git a/mindspore/lite/micro/cmake/file_list.cmake b/mindspore/lite/micro/cmake/file_list.cmake index d8d899b499..8280b42461 100644 --- a/mindspore/lite/micro/cmake/file_list.cmake +++ b/mindspore/lite/micro/cmake/file_list.cmake @@ -91,6 +91,7 @@ set(CODER_OPCODERS_SRC ${MICRO_DIR}/coder/opcoders/nnacl/fp32/softmax_fp32_coder.cc ${MICRO_DIR}/coder/opcoders/nnacl/fp32/tile_fp32_coder.cc ${MICRO_DIR}/coder/opcoders/nnacl/fp32/transpose_fp32_coder.cc + ${MICRO_DIR}/coder/opcoders/nnacl/fp32/splice_fp32_coder.cc #### nnacl int8 coder ${MICRO_DIR}/coder/opcoders/nnacl/int8/activation_int8_coder.cc ${MICRO_DIR}/coder/opcoders/nnacl/int8/add_int8_coder.cc @@ -161,6 +162,7 @@ set(LITE_SRC ${LITE_DIR}/src/ops/populate/bias_add_populate.cc ${LITE_DIR}/src/ops/populate/activation_populate.cc ${LITE_DIR}/src/ops/populate/softmax_populate.cc + ${LITE_DIR}/src/ops/populate/splice_populate.cc ### tools ${LITE_DIR}/tools/common/flag_parser.cc ) diff --git a/mindspore/lite/micro/coder/opcoders/nnacl/fp32/splice_fp32_coder.cc b/mindspore/lite/micro/coder/opcoders/nnacl/fp32/splice_fp32_coder.cc index 739287893f..372e6b9b34 100644 --- a/mindspore/lite/micro/coder/opcoders/nnacl/fp32/splice_fp32_coder.cc +++ b/mindspore/lite/micro/coder/opcoders/nnacl/fp32/splice_fp32_coder.cc @@ -32,14 +32,10 @@ int SpliceFP32Coder::DoCode(CoderContext *const context) { MS_LOG(ERROR) << "SpliceFP32Coder src_shape size not equal to dst_shape"; return RET_ERROR; } - int src_row = src_shape.at(kInputIndex); - int dst_row = dst_shape.at(kInputIndex); + int src_row = src_shape.at(kWeightIndex); + int dst_row = dst_shape.at(kWeightIndex); int src_col = src_shape.at(kBiasIndex); int dst_col = dst_shape.at(kBiasIndex); - if (src_row != dst_row) { - MS_LOG(ERROR) << "SpliceFP32Coder src_row not equal to dst_row"; - return RET_ERROR; - } if (src_col * splice_parameter->context_dim_ != dst_col) { MS_LOG(ERROR) << "SpliceFP32Coder src_col not match to dst_col"; return RET_ERROR; diff --git a/mindspore/lite/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.cc b/mindspore/lite/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.cc index 414e07ee34..a648edcca9 100644 --- a/mindspore/lite/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.cc +++ b/mindspore/lite/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.cc @@ -116,6 +116,7 @@ void NNaclFp32Serializer::CodeStruct(const std::string &name, const DeQuantArg & void NNaclFp32Serializer::CodeStruct(const std::string &name, const SpliceParameter &splice_parameter) { CodeArray("splice_context", splice_parameter.context_, splice_parameter.context_dim_, false); CodeBaseStruct("SpliceParameter", name, splice_parameter.op_parameter_, splice_parameter.context_dim_, - splice_parameter.forward_indexes_dim_, "splice_context", nullptr, splice_parameter.output_dim_); + splice_parameter.forward_indexes_dim_, splice_parameter.src_to_dst_row_offset_, "splice_context", + nullptr, splice_parameter.output_dim_); } } // namespace mindspore::lite::micro::nnacl diff --git a/mindspore/lite/nnacl/fp32/splice_fp32.c b/mindspore/lite/nnacl/fp32/splice_fp32.c index 0682925169..7b4789e62d 100644 --- a/mindspore/lite/nnacl/fp32/splice_fp32.c +++ b/mindspore/lite/nnacl/fp32/splice_fp32.c @@ -17,13 +17,14 @@ #include "nnacl/fp32/splice_fp32.h" void SpliceFp32(const float *src_data, int src_row, int src_col, const SpliceParameter *splice_parameter, float *dst_data, int dst_row, int dst_col) { + int row_offset = splice_parameter->src_to_dst_row_offset_; for (int r = 0; r < dst_row; ++r) { for (int off = 0; off < splice_parameter->context_dim_; ++off) { - int r_off = r + splice_parameter->context_[off]; + int r_off = r + row_offset + splice_parameter->context_[off]; r_off = MSMAX(r_off, 0); r_off = MSMIN(r_off, src_row - 1); - const float *tmp_src_data = src_data + r_off * src_col * sizeof(float); - float *tmp_dst_data = dst_data + r * dst_col * sizeof(float); + const float *tmp_src_data = src_data + r_off * src_col; + float *tmp_dst_data = dst_data + r * dst_col; memcpy(tmp_dst_data + off * src_col, tmp_src_data, src_col * sizeof(float)); } } diff --git a/mindspore/lite/nnacl/infer/infer_register.h b/mindspore/lite/nnacl/infer/infer_register.h index edb7325570..ad5b3ebaf4 100644 --- a/mindspore/lite/nnacl/infer/infer_register.h +++ b/mindspore/lite/nnacl/infer/infer_register.h @@ -212,8 +212,9 @@ enum PrimType { PrimType_SqrtGrad = 185, PrimType_LayerNormGrad = 186, PrimType_ResizeGrad = 187, + PrimType_Splice = 188, PrimType_MIN = PrimType_NONE, - PrimType_MAX = PrimType_ResizeGrad + PrimType_MAX = PrimType_Splice }; void RegInfer(int prim_type, InferShape func); diff --git a/mindspore/lite/nnacl/infer/splice_infer.c b/mindspore/lite/nnacl/infer/splice_infer.c index f1fc80c79a..f82beaaae3 100644 --- a/mindspore/lite/nnacl/infer/splice_infer.c +++ b/mindspore/lite/nnacl/infer/splice_infer.c @@ -54,3 +54,5 @@ int SpliceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC * output->shape_[2] = out_dim; return NNACL_OK; } + +REG_INFER(Splice, PrimType_Splice, SpliceInferShape) diff --git a/mindspore/lite/nnacl/splice_parameter.h b/mindspore/lite/nnacl/splice_parameter.h index d9cc2a45e2..8063960af7 100644 --- a/mindspore/lite/nnacl/splice_parameter.h +++ b/mindspore/lite/nnacl/splice_parameter.h @@ -21,6 +21,7 @@ typedef struct SpliceParameter { OpParameter op_parameter_; int context_dim_; int forward_indexes_dim_; + int src_to_dst_row_offset_; int *context_; int *forward_indexes_; int output_dim_; diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index 8f99072bda..979d2e6457 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -205,6 +205,7 @@ union PrimitiveType { SqrtGrad, LayerNormGrad, ResizeGrad, + Splice, } table Abs { @@ -1087,3 +1088,10 @@ table ResizeGrad { method: ResizeMethod; align_corners: bool; } + +table Splice { + context: [long]; + forward_indexes: [long]; + output_dim: long; +} + diff --git a/mindspore/lite/src/ops/ops_def.cc b/mindspore/lite/src/ops/ops_def.cc index 4cf237c54b..679ee21920 100644 --- a/mindspore/lite/src/ops/ops_def.cc +++ b/mindspore/lite/src/ops/ops_def.cc @@ -204,6 +204,7 @@ OP_TYPE(RsqrtGrad) OP_TYPE(SqrtGrad) OP_TYPE(LayerNormGrad) OP_TYPE(ResizeGrad) +OP_TYPE(Splice) OP_TYPE_DEF_END(PrimitiveType) OP_SCHEMA_DEF(Abs) @@ -1086,3 +1087,9 @@ OP_SCHEMA_DEF(ResizeGrad) OP_ATTR_ENUM(method, ResizeMethod) OP_ATTR(align_corners, bool) OP_SCHEMA_DEF_END(ResizeGrad) + +OP_SCHEMA_DEF(Splice) +OP_ATTR(context, [long]) +OP_ATTR(forward_indexes, [long]) +OP_ATTR(output_dim, long) +OP_SCHEMA_DEF_END(Splice) diff --git a/mindspore/lite/src/ops/ops_func_declare.h b/mindspore/lite/src/ops/ops_func_declare.h index a2724ab218..fa150d92c3 100644 --- a/mindspore/lite/src/ops/ops_func_declare.h +++ b/mindspore/lite/src/ops/ops_func_declare.h @@ -242,6 +242,7 @@ #include "ops/lin_space.h" #include "ops/uniform_real.h" #include "ops/grad/abs_grad.h" +#include "ops/splice.h" #define FUNC_MSOP2SCHEMAOP_DECLARE(OP) \ namespace mindspore::lite::ops { \ @@ -453,5 +454,6 @@ FUNC_MSOP2SCHEMAOP_DECLARE(RsqrtGrad); FUNC_MSOP2SCHEMAOP_DECLARE(SqrtGrad); FUNC_MSOP2SCHEMAOP_DECLARE(LayerNormGrad); FUNC_MSOP2SCHEMAOP_DECLARE(ResizeGrad); +FUNC_MSOP2SCHEMAOP_DECLARE(Splice); #endif #endif // MINDSPORE_LITE_SRC_OPS_OPS_FUNC_DECLARE_H_ diff --git a/mindspore/lite/src/ops/ops_utils.cc b/mindspore/lite/src/ops/ops_utils.cc index f868cb08a7..a7a73f5432 100644 --- a/mindspore/lite/src/ops/ops_utils.cc +++ b/mindspore/lite/src/ops/ops_utils.cc @@ -745,6 +745,11 @@ schema::PrimitiveT *ErfPrimitiveCreator(const AnfNodePtr &node) { return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; } +schema::PrimitiveT *SplicePrimitiveCreator(const AnfNodePtr &node) { + auto ms_primc = GetValueNode>(node); + return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; +} + RegistryMSOps g_absPrimitiveCreatorRegistry("Abs", AbsPrimitiveCreator); RegistryMSOps g_absGradPrimitiveCreatorRegistry("AbsGrad", AbsGradPrimitiveCreator); RegistryMSOps g_activationPrimitiveCreatorRegistry("Activation", ActivationPrimitiveCreator); @@ -954,6 +959,7 @@ RegistryMSOps g_unsqueezePrimitiveCreatorRegistry("Unsqueeze", UnsqueezePrimitiv RegistryMSOps g_wherePrimitiveCreatorRegistry("Where", WherePrimitiveCreator); RegistryMSOps g_zerosLikePrimitiveCreatorRegistry("ZerosLike", ZerosLikePrimitiveCreator); RegistryMSOps g_erfPrimitiveCreatorRegistry("Erf", ErfPrimitiveCreator); +RegistryMSOps g_SplicePrimitiveCreatorRegistry("Splice", SplicePrimitiveCreator); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/populate/splice_populate.cc b/mindspore/lite/src/ops/populate/splice_populate.cc new file mode 100644 index 0000000000..6b74e3ef53 --- /dev/null +++ b/mindspore/lite/src/ops/populate/splice_populate.cc @@ -0,0 +1,72 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/ops/populate/populate_register.h" +#include "nnacl/op_base.h" +#include "nnacl/splice_parameter.h" +namespace mindspore { +namespace lite { +OpParameter *PopulateSpliceParameter(const void *prim) { + auto *splice_parameter = reinterpret_cast(malloc(sizeof(SpliceParameter))); + if (splice_parameter == nullptr) { + MS_LOG(ERROR) << "malloc Splice Parameter failed."; + return nullptr; + } + memset(splice_parameter, 0, sizeof(SpliceParameter)); + auto primitive = static_cast(prim); + auto splice_primitive = primitive->value_as_Splice(); + splice_parameter->op_parameter_.type_ = primitive->value_type(); + + std::vector primitive_context(splice_primitive->context()->begin(), splice_primitive->context()->end()); + splice_parameter->context_dim_ = static_cast(primitive_context.size()); + + // malloc && memset for context + splice_parameter->context_ = reinterpret_cast(malloc(splice_parameter->context_dim_ * sizeof(int))); + if (splice_parameter->context_ == nullptr) { + MS_LOG(ERROR) << "malloc splice_parameter context_ error"; + free(splice_parameter); + return nullptr; + } + // src_to_dst_row_offset + int src_to_dst_row_offset = INT32_MIN; + memset(splice_parameter->context_, 0, splice_parameter->context_dim_ * sizeof(int)); + for (int i = 0; i < splice_parameter->context_dim_; ++i) { + splice_parameter->context_[i] = primitive_context.at(i); + src_to_dst_row_offset = std::max(src_to_dst_row_offset, std::abs(primitive_context.at(i))); + } + + std::vector primitive_forward_indexes(splice_primitive->forward_indexes()->begin(), + splice_primitive->forward_indexes()->end()); + splice_parameter->forward_indexes_dim_ = static_cast(primitive_forward_indexes.size()); + + // malloc && memset for forward_indexes + splice_parameter->forward_indexes_ = + reinterpret_cast(malloc(splice_parameter->forward_indexes_dim_ * sizeof(int))); + if (splice_parameter->forward_indexes_ == nullptr) { + MS_LOG(ERROR) << "malloc splice_parameter forward_indexes_ error"; + free(splice_parameter->context_); + free(splice_parameter); + return nullptr; + } + memset(splice_parameter->forward_indexes_, 0, splice_parameter->forward_indexes_dim_ * sizeof(int)); + for (int i = 0; i < splice_parameter->context_dim_; ++i) { + splice_parameter->context_[i] = primitive_context.at(i); + } + splice_parameter->output_dim_ = splice_primitive->output_dim(); + return reinterpret_cast(splice_parameter); +} +Registry g_SpliceParameterRegistry(schema::PrimitiveType_Splice, PopulateSpliceParameter, SCHEMA_CUR); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_splice_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_splice_parser.cc new file mode 100644 index 0000000000..9b5cf9ef36 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_splice_parser.cc @@ -0,0 +1,58 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/onnx/onnx_splice_parser.h" +#include +#include +#include +#include "ops/splice.h" + +namespace mindspore { +namespace lite { +ops::PrimitiveC *OnnxSpliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { + MS_LOG(DEBUG) << "onnx Splice Parser"; + auto primitive = std::make_unique(); + std::vector context; + std::vector forward_indexes; + int64_t output_dim = 0; + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const std::string attribute_name = onnx_node_attr.name(); + if (attribute_name == "context") { + const int32_t size = onnx_node_attr.ints_size(); + context.resize(size); + for (int32_t i = 0; i < size; i++) { + context[i] = static_cast(onnx_node_attr.ints(i)); + } + } else if (attribute_name == "forward_indexes") { + const int32_t size = onnx_node_attr.ints_size(); + forward_indexes.resize(size); + for (int32_t i = 0; i < size; i++) { + forward_indexes[i] = static_cast(onnx_node_attr.ints(i)); + } + } else if (attribute_name == "output_dim") { + output_dim = static_cast(onnx_node_attr.i()); + } else { + MS_LOG(ERROR) << "unsupported attribute in splice " << attribute_name; + return nullptr; + } + } + primitive->Init(context, forward_indexes, output_dim); + return primitive.release(); +} + +OnnxNodeRegistrar g_onnxSpliceParser("Splice", new OnnxSpliceParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_splice_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_splice_parser.h new file mode 100644 index 0000000000..07542837c0 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_splice_parser.h @@ -0,0 +1,34 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SPLICE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SPLICE_PARSER_H + +#include "tools/converter/parser/onnx/onnx_node_parser.h" +#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxSpliceParser : public OnnxNodeParser { + public: + OnnxSpliceParser() : OnnxNodeParser("Splice") {} + ~OnnxSpliceParser() override = default; + + ops::PrimitiveC *Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SPLICE_PARSER_H