From: @zoloft Reviewed-by: @wangchengyuan Signed-off-by: @wangchengyuantags/v1.2.0-rc1
| @@ -226,6 +226,9 @@ constexpr auto kCoeff = "coeff"; | |||
| constexpr auto kIsDepthWise = "is_depth_wise"; | |||
| constexpr auto kZoneoutCell = "zoneout_cell"; | |||
| constexpr auto kZoneoutHidden = "zoneout_hidden"; | |||
| constexpr auto kSpliceContext = "context"; | |||
| constexpr auto kSpliceForwardIndexes = "forward_indexes"; | |||
| constexpr auto kSpliceOutputDims = "output_dim"; | |||
| const std::set<TypeId> common_valid_types = { | |||
| kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt64, kNumberTypeUInt8, kNumberTypeUInt16, | |||
| @@ -0,0 +1,54 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ops/splice.h" | |||
| #include <vector> | |||
| #include "ops/op_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| void Splice::Init(const std::vector<int64_t> &contexts, const std::vector<int64_t> &forward_indexes, | |||
| int64_t output_dims) { | |||
| this->set_context(contexts); | |||
| this->set_forward_indexes(forward_indexes); | |||
| this->set_output_dim(output_dims); | |||
| } | |||
| void Splice::set_context(const std::vector<int64_t> &contexts) { this->AddAttr(kSpliceContext, MakeValue(contexts)); } | |||
| void Splice::set_forward_indexes(const std::vector<int64_t> &forward_indexes) { | |||
| this->AddAttr(kSpliceForwardIndexes, MakeValue(forward_indexes)); | |||
| } | |||
| void Splice::set_output_dim(int64_t output_dim) { this->AddAttr(kSpliceOutputDims, MakeValue(output_dim)); } | |||
| std::vector<int64_t> Splice::get_context() const { | |||
| auto value_ptr = GetAttr(kSpliceContext); | |||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||
| } | |||
| std::vector<int64_t> Splice::get_forward_indexes() const { | |||
| auto value_ptr = GetAttr(kSpliceForwardIndexes); | |||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||
| } | |||
| int64_t Splice::get_output_dim() const { | |||
| auto value_ptr = GetAttr(kSpliceOutputDims); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| REGISTER_PRIMITIVE_C(kNameSplice, Splice); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,47 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_OPS_SPLICE_H_ | |||
| #define MINDSPORE_CORE_OPS_SPLICE_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameSplice = "Splice"; | |||
| class Splice : public PrimitiveC { | |||
| public: | |||
| Splice() : PrimitiveC(kNameSplice) { InitIOName({"inputs"}, {"outputs"}); } | |||
| ~Splice() = default; | |||
| MS_DECLARE_PARENT(Splice, PrimitiveC); | |||
| void Init(const std::vector<int64_t> &contexts, const std::vector<int64_t> &forward_indexes, int64_t output_dims); | |||
| void set_context(const std::vector<int64_t> &contexts); | |||
| void set_forward_indexes(const std::vector<int64_t> &forward_indexes); | |||
| void set_output_dim(int64_t output_dim); | |||
| std::vector<int64_t> get_context() const; | |||
| std::vector<int64_t> get_forward_indexes() const; | |||
| int64_t get_output_dim() const; | |||
| AbstractBasePtr SpliceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| }; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_OPS_SPLICE_H_ | |||
| @@ -91,6 +91,7 @@ set(CODER_OPCODERS_SRC | |||
| ${MICRO_DIR}/coder/opcoders/nnacl/fp32/softmax_fp32_coder.cc | |||
| ${MICRO_DIR}/coder/opcoders/nnacl/fp32/tile_fp32_coder.cc | |||
| ${MICRO_DIR}/coder/opcoders/nnacl/fp32/transpose_fp32_coder.cc | |||
| ${MICRO_DIR}/coder/opcoders/nnacl/fp32/splice_fp32_coder.cc | |||
| #### nnacl int8 coder | |||
| ${MICRO_DIR}/coder/opcoders/nnacl/int8/activation_int8_coder.cc | |||
| ${MICRO_DIR}/coder/opcoders/nnacl/int8/add_int8_coder.cc | |||
| @@ -161,6 +162,7 @@ set(LITE_SRC | |||
| ${LITE_DIR}/src/ops/populate/bias_add_populate.cc | |||
| ${LITE_DIR}/src/ops/populate/activation_populate.cc | |||
| ${LITE_DIR}/src/ops/populate/softmax_populate.cc | |||
| ${LITE_DIR}/src/ops/populate/splice_populate.cc | |||
| ### tools | |||
| ${LITE_DIR}/tools/common/flag_parser.cc | |||
| ) | |||
| @@ -32,14 +32,10 @@ int SpliceFP32Coder::DoCode(CoderContext *const context) { | |||
| MS_LOG(ERROR) << "SpliceFP32Coder src_shape size not equal to dst_shape"; | |||
| return RET_ERROR; | |||
| } | |||
| int src_row = src_shape.at(kInputIndex); | |||
| int dst_row = dst_shape.at(kInputIndex); | |||
| int src_row = src_shape.at(kWeightIndex); | |||
| int dst_row = dst_shape.at(kWeightIndex); | |||
| int src_col = src_shape.at(kBiasIndex); | |||
| int dst_col = dst_shape.at(kBiasIndex); | |||
| if (src_row != dst_row) { | |||
| MS_LOG(ERROR) << "SpliceFP32Coder src_row not equal to dst_row"; | |||
| return RET_ERROR; | |||
| } | |||
| if (src_col * splice_parameter->context_dim_ != dst_col) { | |||
| MS_LOG(ERROR) << "SpliceFP32Coder src_col not match to dst_col"; | |||
| return RET_ERROR; | |||
| @@ -116,6 +116,7 @@ void NNaclFp32Serializer::CodeStruct(const std::string &name, const DeQuantArg & | |||
| void NNaclFp32Serializer::CodeStruct(const std::string &name, const SpliceParameter &splice_parameter) { | |||
| CodeArray("splice_context", splice_parameter.context_, splice_parameter.context_dim_, false); | |||
| CodeBaseStruct("SpliceParameter", name, splice_parameter.op_parameter_, splice_parameter.context_dim_, | |||
| splice_parameter.forward_indexes_dim_, "splice_context", nullptr, splice_parameter.output_dim_); | |||
| splice_parameter.forward_indexes_dim_, splice_parameter.src_to_dst_row_offset_, "splice_context", | |||
| nullptr, splice_parameter.output_dim_); | |||
| } | |||
| } // namespace mindspore::lite::micro::nnacl | |||
| @@ -17,13 +17,14 @@ | |||
| #include "nnacl/fp32/splice_fp32.h" | |||
| void SpliceFp32(const float *src_data, int src_row, int src_col, const SpliceParameter *splice_parameter, | |||
| float *dst_data, int dst_row, int dst_col) { | |||
| int row_offset = splice_parameter->src_to_dst_row_offset_; | |||
| for (int r = 0; r < dst_row; ++r) { | |||
| for (int off = 0; off < splice_parameter->context_dim_; ++off) { | |||
| int r_off = r + splice_parameter->context_[off]; | |||
| int r_off = r + row_offset + splice_parameter->context_[off]; | |||
| r_off = MSMAX(r_off, 0); | |||
| r_off = MSMIN(r_off, src_row - 1); | |||
| const float *tmp_src_data = src_data + r_off * src_col * sizeof(float); | |||
| float *tmp_dst_data = dst_data + r * dst_col * sizeof(float); | |||
| const float *tmp_src_data = src_data + r_off * src_col; | |||
| float *tmp_dst_data = dst_data + r * dst_col; | |||
| memcpy(tmp_dst_data + off * src_col, tmp_src_data, src_col * sizeof(float)); | |||
| } | |||
| } | |||
| @@ -212,8 +212,9 @@ enum PrimType { | |||
| PrimType_SqrtGrad = 185, | |||
| PrimType_LayerNormGrad = 186, | |||
| PrimType_ResizeGrad = 187, | |||
| PrimType_Splice = 188, | |||
| PrimType_MIN = PrimType_NONE, | |||
| PrimType_MAX = PrimType_ResizeGrad | |||
| PrimType_MAX = PrimType_Splice | |||
| }; | |||
| void RegInfer(int prim_type, InferShape func); | |||
| @@ -54,3 +54,5 @@ int SpliceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC * | |||
| output->shape_[2] = out_dim; | |||
| return NNACL_OK; | |||
| } | |||
| REG_INFER(Splice, PrimType_Splice, SpliceInferShape) | |||
| @@ -21,6 +21,7 @@ typedef struct SpliceParameter { | |||
| OpParameter op_parameter_; | |||
| int context_dim_; | |||
| int forward_indexes_dim_; | |||
| int src_to_dst_row_offset_; | |||
| int *context_; | |||
| int *forward_indexes_; | |||
| int output_dim_; | |||
| @@ -205,6 +205,7 @@ union PrimitiveType { | |||
| SqrtGrad, | |||
| LayerNormGrad, | |||
| ResizeGrad, | |||
| Splice, | |||
| } | |||
| table Abs { | |||
| @@ -1087,3 +1088,10 @@ table ResizeGrad { | |||
| method: ResizeMethod; | |||
| align_corners: bool; | |||
| } | |||
| table Splice { | |||
| context: [long]; | |||
| forward_indexes: [long]; | |||
| output_dim: long; | |||
| } | |||
| @@ -204,6 +204,7 @@ OP_TYPE(RsqrtGrad) | |||
| OP_TYPE(SqrtGrad) | |||
| OP_TYPE(LayerNormGrad) | |||
| OP_TYPE(ResizeGrad) | |||
| OP_TYPE(Splice) | |||
| OP_TYPE_DEF_END(PrimitiveType) | |||
| OP_SCHEMA_DEF(Abs) | |||
| @@ -1086,3 +1087,9 @@ OP_SCHEMA_DEF(ResizeGrad) | |||
| OP_ATTR_ENUM(method, ResizeMethod) | |||
| OP_ATTR(align_corners, bool) | |||
| OP_SCHEMA_DEF_END(ResizeGrad) | |||
| OP_SCHEMA_DEF(Splice) | |||
| OP_ATTR(context, [long]) | |||
| OP_ATTR(forward_indexes, [long]) | |||
| OP_ATTR(output_dim, long) | |||
| OP_SCHEMA_DEF_END(Splice) | |||
| @@ -242,6 +242,7 @@ | |||
| #include "ops/lin_space.h" | |||
| #include "ops/uniform_real.h" | |||
| #include "ops/grad/abs_grad.h" | |||
| #include "ops/splice.h" | |||
| #define FUNC_MSOP2SCHEMAOP_DECLARE(OP) \ | |||
| namespace mindspore::lite::ops { \ | |||
| @@ -453,5 +454,6 @@ FUNC_MSOP2SCHEMAOP_DECLARE(RsqrtGrad); | |||
| FUNC_MSOP2SCHEMAOP_DECLARE(SqrtGrad); | |||
| FUNC_MSOP2SCHEMAOP_DECLARE(LayerNormGrad); | |||
| FUNC_MSOP2SCHEMAOP_DECLARE(ResizeGrad); | |||
| FUNC_MSOP2SCHEMAOP_DECLARE(Splice); | |||
| #endif | |||
| #endif // MINDSPORE_LITE_SRC_OPS_OPS_FUNC_DECLARE_H_ | |||
| @@ -745,6 +745,11 @@ schema::PrimitiveT *ErfPrimitiveCreator(const AnfNodePtr &node) { | |||
| return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | |||
| } | |||
| schema::PrimitiveT *SplicePrimitiveCreator(const AnfNodePtr &node) { | |||
| auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Splice>>(node); | |||
| return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | |||
| } | |||
| RegistryMSOps g_absPrimitiveCreatorRegistry("Abs", AbsPrimitiveCreator); | |||
| RegistryMSOps g_absGradPrimitiveCreatorRegistry("AbsGrad", AbsGradPrimitiveCreator); | |||
| RegistryMSOps g_activationPrimitiveCreatorRegistry("Activation", ActivationPrimitiveCreator); | |||
| @@ -954,6 +959,7 @@ RegistryMSOps g_unsqueezePrimitiveCreatorRegistry("Unsqueeze", UnsqueezePrimitiv | |||
| RegistryMSOps g_wherePrimitiveCreatorRegistry("Where", WherePrimitiveCreator); | |||
| RegistryMSOps g_zerosLikePrimitiveCreatorRegistry("ZerosLike", ZerosLikePrimitiveCreator); | |||
| RegistryMSOps g_erfPrimitiveCreatorRegistry("Erf", ErfPrimitiveCreator); | |||
| RegistryMSOps g_SplicePrimitiveCreatorRegistry("Splice", SplicePrimitiveCreator); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,72 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/populate/populate_register.h" | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/splice_parameter.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| OpParameter *PopulateSpliceParameter(const void *prim) { | |||
| auto *splice_parameter = reinterpret_cast<SpliceParameter *>(malloc(sizeof(SpliceParameter))); | |||
| if (splice_parameter == nullptr) { | |||
| MS_LOG(ERROR) << "malloc Splice Parameter failed."; | |||
| return nullptr; | |||
| } | |||
| memset(splice_parameter, 0, sizeof(SpliceParameter)); | |||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||
| auto splice_primitive = primitive->value_as_Splice(); | |||
| splice_parameter->op_parameter_.type_ = primitive->value_type(); | |||
| std::vector<int> primitive_context(splice_primitive->context()->begin(), splice_primitive->context()->end()); | |||
| splice_parameter->context_dim_ = static_cast<int>(primitive_context.size()); | |||
| // malloc && memset for context | |||
| splice_parameter->context_ = reinterpret_cast<int *>(malloc(splice_parameter->context_dim_ * sizeof(int))); | |||
| if (splice_parameter->context_ == nullptr) { | |||
| MS_LOG(ERROR) << "malloc splice_parameter context_ error"; | |||
| free(splice_parameter); | |||
| return nullptr; | |||
| } | |||
| // src_to_dst_row_offset | |||
| int src_to_dst_row_offset = INT32_MIN; | |||
| memset(splice_parameter->context_, 0, splice_parameter->context_dim_ * sizeof(int)); | |||
| for (int i = 0; i < splice_parameter->context_dim_; ++i) { | |||
| splice_parameter->context_[i] = primitive_context.at(i); | |||
| src_to_dst_row_offset = std::max(src_to_dst_row_offset, std::abs(primitive_context.at(i))); | |||
| } | |||
| std::vector<int> primitive_forward_indexes(splice_primitive->forward_indexes()->begin(), | |||
| splice_primitive->forward_indexes()->end()); | |||
| splice_parameter->forward_indexes_dim_ = static_cast<int>(primitive_forward_indexes.size()); | |||
| // malloc && memset for forward_indexes | |||
| splice_parameter->forward_indexes_ = | |||
| reinterpret_cast<int *>(malloc(splice_parameter->forward_indexes_dim_ * sizeof(int))); | |||
| if (splice_parameter->forward_indexes_ == nullptr) { | |||
| MS_LOG(ERROR) << "malloc splice_parameter forward_indexes_ error"; | |||
| free(splice_parameter->context_); | |||
| free(splice_parameter); | |||
| return nullptr; | |||
| } | |||
| memset(splice_parameter->forward_indexes_, 0, splice_parameter->forward_indexes_dim_ * sizeof(int)); | |||
| for (int i = 0; i < splice_parameter->context_dim_; ++i) { | |||
| splice_parameter->context_[i] = primitive_context.at(i); | |||
| } | |||
| splice_parameter->output_dim_ = splice_primitive->output_dim(); | |||
| return reinterpret_cast<OpParameter *>(splice_parameter); | |||
| } | |||
| Registry g_SpliceParameterRegistry(schema::PrimitiveType_Splice, PopulateSpliceParameter, SCHEMA_CUR); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,58 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/onnx/onnx_splice_parser.h" | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "ops/splice.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxSpliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx Splice Parser"; | |||
| auto primitive = std::make_unique<ops::Splice>(); | |||
| std::vector<int64_t> context; | |||
| std::vector<int64_t> forward_indexes; | |||
| int64_t output_dim = 0; | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const std::string attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "context") { | |||
| const int32_t size = onnx_node_attr.ints_size(); | |||
| context.resize(size); | |||
| for (int32_t i = 0; i < size; i++) { | |||
| context[i] = static_cast<int>(onnx_node_attr.ints(i)); | |||
| } | |||
| } else if (attribute_name == "forward_indexes") { | |||
| const int32_t size = onnx_node_attr.ints_size(); | |||
| forward_indexes.resize(size); | |||
| for (int32_t i = 0; i < size; i++) { | |||
| forward_indexes[i] = static_cast<int>(onnx_node_attr.ints(i)); | |||
| } | |||
| } else if (attribute_name == "output_dim") { | |||
| output_dim = static_cast<int>(onnx_node_attr.i()); | |||
| } else { | |||
| MS_LOG(ERROR) << "unsupported attribute in splice " << attribute_name; | |||
| return nullptr; | |||
| } | |||
| } | |||
| primitive->Init(context, forward_indexes, output_dim); | |||
| return primitive.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxSpliceParser("Splice", new OnnxSpliceParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,34 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SPLICE_PARSER_H | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SPLICE_PARSER_H | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class OnnxSpliceParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxSpliceParser() : OnnxNodeParser("Splice") {} | |||
| ~OnnxSpliceParser() override = default; | |||
| ops::PrimitiveC *Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SPLICE_PARSER_H | |||