| @@ -39,11 +39,11 @@ class Arithmetic : public PrimitiveC { | |||
| } | |||
| #endif | |||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||
| bool Broadcasting() { return this->broadcasting_; } | |||
| int NDims() { return this->ndim_; } | |||
| std::vector<int> InShape0() { return this->in_shape0_; } | |||
| std::vector<int> InShape1() { return this->in_shape1_; } | |||
| std::vector<int> OutputShape() { return this->out_shape_; } | |||
| bool Broadcasting() const { return this->broadcasting_; } | |||
| int NDims() const { return this->ndim_; } | |||
| std::vector<int> InShape0() const { return this->in_shape0_; } | |||
| std::vector<int> InShape1() const { return this->in_shape1_; } | |||
| std::vector<int> OutputShape() const { return this->out_shape_; } | |||
| protected: | |||
| bool broadcasting_ = false; | |||
| @@ -21,20 +21,20 @@ | |||
| #include <set> | |||
| #include <cmath> | |||
| #include "src/ops/primitive_c.h" | |||
| #include "src/ops/arithmetic.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class Eltwise : public PrimitiveC { | |||
| class Eltwise : public Arithmetic { | |||
| public: | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(Eltwise, PrimitiveC); | |||
| MS_DECLARE_PARENT(Eltwise, Arithmetic); | |||
| Eltwise() = default; | |||
| explicit Eltwise(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| explicit Eltwise(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} | |||
| void SetMode(int mode); | |||
| #else | |||
| Eltwise() = default; | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| int GetMode() const; | |||
| @@ -1,46 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/add.h" | |||
| #include "src/ops/primitive_c.h" | |||
| #include "src/ops/populate/populate_register.h" | |||
| #include "nnacl/arithmetic_common.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| OpParameter *PopulateAddParameter(const mindspore::lite::PrimitiveC *primitive) { | |||
| ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter))); | |||
| if (arithmetic_param == nullptr) { | |||
| MS_LOG(ERROR) << "malloc ArithmeticParameter failed."; | |||
| return nullptr; | |||
| } | |||
| memset(arithmetic_param, 0, sizeof(ArithmeticParameter)); | |||
| arithmetic_param->op_parameter_.type_ = primitive->Type(); | |||
| arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting(); | |||
| arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims(); | |||
| arithmetic_param->activation_type_ = | |||
| reinterpret_cast<mindspore::lite::Add *>(const_cast<mindspore::lite::PrimitiveC *>(primitive))->GetActivationType(); | |||
| auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0(); | |||
| memcpy(arithmetic_param->in_shape0_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); | |||
| tmp_shape = ((lite::Arithmetic *)primitive)->InShape1(); | |||
| memcpy(arithmetic_param->in_shape1_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); | |||
| tmp_shape = ((lite::Arithmetic *)primitive)->OutputShape(); | |||
| memcpy(arithmetic_param->out_shape_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); | |||
| return reinterpret_cast<OpParameter *>(arithmetic_param); | |||
| } | |||
| Registry AddParameterRegistry(schema::PrimitiveType_Add, PopulateAddParameter); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -13,8 +13,13 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/arithmetic.h" | |||
| #include "src/ops/add.h" | |||
| #include "src/ops/sub.h" | |||
| #include "src/ops/mul.h" | |||
| #include "src/ops/div.h" | |||
| #include "src/ops/eltwise.h" | |||
| #include "src/ops/greater_equal.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "src/tensor.h" | |||
| #include "src/ops/primitive_c.h" | |||
| @@ -22,27 +27,98 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ArithmeticParameter *PopulateArithmeticCommonPara(const mindspore::lite::PrimitiveC *primitive) { | |||
| ArithmeticParameter *param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter))); | |||
| if (param == nullptr) { | |||
| MS_LOG(ERROR) << "malloc ArithmeticParameter failed."; | |||
| return nullptr; | |||
| } | |||
| memset(param, 0, sizeof(ArithmeticParameter)); | |||
| param->op_parameter_.type_ = primitive->Type(); | |||
| param->broadcasting_ = reinterpret_cast<const lite::Arithmetic *>(primitive)->Broadcasting(); | |||
| param->ndim_ = reinterpret_cast<const lite::Arithmetic *>(primitive)->NDims(); | |||
| param->activation_type_ = 0; | |||
| auto tmp_shape = reinterpret_cast<const lite::Arithmetic *>(primitive)->InShape0(); | |||
| memcpy(param->in_shape0_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); | |||
| tmp_shape = reinterpret_cast<const lite::Arithmetic *>(primitive)->InShape1(); | |||
| memcpy(param->in_shape1_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); | |||
| tmp_shape = reinterpret_cast<const lite::Arithmetic *>(primitive)->OutputShape(); | |||
| memcpy(param->out_shape_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); | |||
| return param; | |||
| } | |||
| OpParameter *PopulateArithmetic(const mindspore::lite::PrimitiveC *primitive) { | |||
| ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter))); | |||
| if (arithmetic_param == nullptr) { | |||
| MS_LOG(ERROR) << "malloc ArithmeticParameter failed."; | |||
| ArithmeticParameter *param = PopulateArithmeticCommonPara(primitive); | |||
| if (param == nullptr) { | |||
| MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed."; | |||
| return nullptr; | |||
| } | |||
| return reinterpret_cast<OpParameter *>(param); | |||
| } | |||
| OpParameter *PopulateAddParameter(const mindspore::lite::PrimitiveC *primitive) { | |||
| ArithmeticParameter *param = PopulateArithmeticCommonPara(primitive); | |||
| if (param == nullptr) { | |||
| MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed."; | |||
| return nullptr; | |||
| } | |||
| param->activation_type_ = reinterpret_cast<const mindspore::lite::Add *>(primitive)->GetActivationType(); | |||
| return reinterpret_cast<OpParameter *>(param); | |||
| } | |||
| OpParameter *PopulateSubParameter(const mindspore::lite::PrimitiveC *primitive) { | |||
| ArithmeticParameter *param = PopulateArithmeticCommonPara(primitive); | |||
| if (param == nullptr) { | |||
| MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed."; | |||
| return nullptr; | |||
| } | |||
| memset(arithmetic_param, 0, sizeof(ArithmeticParameter)); | |||
| arithmetic_param->op_parameter_.type_ = primitive->Type(); | |||
| arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting(); | |||
| arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims(); | |||
| param->activation_type_ = reinterpret_cast<const mindspore::lite::Sub *>(primitive)->GetActivationType(); | |||
| return reinterpret_cast<OpParameter *>(param); | |||
| } | |||
| arithmetic_param->activation_type_ = 0; | |||
| OpParameter *PopulateMulParameter(const mindspore::lite::PrimitiveC *primitive) { | |||
| ArithmeticParameter *param = PopulateArithmeticCommonPara(primitive); | |||
| if (param == nullptr) { | |||
| MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed."; | |||
| return nullptr; | |||
| } | |||
| param->activation_type_ = reinterpret_cast<const mindspore::lite::Mul *>(primitive)->GetActivationType(); | |||
| return reinterpret_cast<OpParameter *>(param); | |||
| } | |||
| auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0(); | |||
| memcpy(arithmetic_param->in_shape0_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); | |||
| tmp_shape = ((lite::Arithmetic *)primitive)->InShape1(); | |||
| memcpy(arithmetic_param->in_shape1_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); | |||
| tmp_shape = ((lite::Arithmetic *)primitive)->OutputShape(); | |||
| memcpy(arithmetic_param->out_shape_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); | |||
| return reinterpret_cast<OpParameter *>(arithmetic_param); | |||
| OpParameter *PopulateDivParameter(const mindspore::lite::PrimitiveC *primitive) { | |||
| ArithmeticParameter *param = PopulateArithmeticCommonPara(primitive); | |||
| if (param == nullptr) { | |||
| MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed."; | |||
| return nullptr; | |||
| } | |||
| param->activation_type_ = reinterpret_cast<const mindspore::lite::Div *>(primitive)->GetActivationType(); | |||
| return reinterpret_cast<OpParameter *>(param); | |||
| } | |||
| OpParameter *PopulateEltwiseParameter(const mindspore::lite::PrimitiveC *primitive) { | |||
| ArithmeticParameter *param = PopulateArithmeticCommonPara(primitive); | |||
| if (param == nullptr) { | |||
| MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed."; | |||
| return nullptr; | |||
| } | |||
| auto eltwise = reinterpret_cast<const mindspore::lite::Eltwise *>(primitive); | |||
| switch (eltwise->GetMode()) { | |||
| case schema::EltwiseMode_PROD: | |||
| param->op_parameter_.type_ = schema::PrimitiveType_Mul; | |||
| break; | |||
| case schema::EltwiseMode_SUM: | |||
| param->op_parameter_.type_ = schema::PrimitiveType_Add; | |||
| break; | |||
| case schema::EltwiseMode_MAXIMUM: | |||
| param->op_parameter_.type_ = schema::PrimitiveType_Maximum; | |||
| break; | |||
| default: | |||
| free(param); | |||
| return nullptr; | |||
| } | |||
| return reinterpret_cast<OpParameter *>(param); | |||
| } | |||
| Registry RealDivParameterRegistry(schema::PrimitiveType_RealDiv, PopulateArithmetic); | |||
| @@ -51,6 +127,7 @@ Registry ParameterRegistry(schema::PrimitiveType_LogicalOr, PopulateArithmetic); | |||
| Registry EqualParameterRegistry(schema::PrimitiveType_Equal, PopulateArithmetic); | |||
| Registry LessParameterRegistry(schema::PrimitiveType_Less, PopulateArithmetic); | |||
| Registry GreaterParameterRegistry(schema::PrimitiveType_Greater, PopulateArithmetic); | |||
| Registry GreaterEqualParameterRegistry(schema::PrimitiveType_GreaterEqual, PopulateArithmetic); | |||
| Registry NotEqualParameterRegistry(schema::PrimitiveType_NotEqual, PopulateArithmetic); | |||
| Registry LessEqualParameterRegistry(schema::PrimitiveType_LessEqual, PopulateArithmetic); | |||
| Registry MaximumParameterRegistry(schema::PrimitiveType_Maximum, PopulateArithmetic); | |||
| @@ -58,5 +135,10 @@ Registry MinimumParameterRegistry(schema::PrimitiveType_Minimum, PopulateArithme | |||
| Registry FloorDivParameterRegistry(schema::PrimitiveType_FloorDiv, PopulateArithmetic); | |||
| Registry FloorModParameterRegistry(schema::PrimitiveType_FloorMod, PopulateArithmetic); | |||
| Registry SquaredDifferenceParameterRegistry(schema::PrimitiveType_SquaredDifference, PopulateArithmetic); | |||
| Registry AddParameterRegistry(schema::PrimitiveType_Add, PopulateAddParameter); | |||
| Registry SubParameterRegistry(schema::PrimitiveType_Sub, PopulateSubParameter); | |||
| Registry MulParameterRegistry(schema::PrimitiveType_Mul, PopulateMulParameter); | |||
| Registry DivParameterRegistry(schema::PrimitiveType_Div, PopulateDivParameter); | |||
| Registry EltwiseParameterRegistry(schema::PrimitiveType_Eltwise, PopulateEltwiseParameter); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -1,47 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/div.h" | |||
| #include "src/ops/primitive_c.h" | |||
| #include "src/ops/populate/populate_register.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| OpParameter *PopulateDivParameter(const mindspore::lite::PrimitiveC *primitive) { | |||
| ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter))); | |||
| if (arithmetic_param == nullptr) { | |||
| MS_LOG(ERROR) << "malloc ArithmeticParameter failed."; | |||
| return nullptr; | |||
| } | |||
| memset(arithmetic_param, 0, sizeof(ArithmeticParameter)); | |||
| arithmetic_param->op_parameter_.type_ = primitive->Type(); | |||
| arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting(); | |||
| arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims(); | |||
| arithmetic_param->activation_type_ = | |||
| reinterpret_cast<mindspore::lite::Div *>(const_cast<mindspore::lite::PrimitiveC *>(primitive))->GetActivationType(); | |||
| auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0(); | |||
| memcpy(arithmetic_param->in_shape0_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); | |||
| tmp_shape = ((lite::Arithmetic *)primitive)->InShape1(); | |||
| memcpy(arithmetic_param->in_shape1_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); | |||
| tmp_shape = ((lite::Arithmetic *)primitive)->OutputShape(); | |||
| memcpy(arithmetic_param->out_shape_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); | |||
| return reinterpret_cast<OpParameter *>(arithmetic_param); | |||
| } | |||
| Registry DivParameterRegistry(schema::PrimitiveType_Div, PopulateDivParameter); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -1,52 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/eltwise.h" | |||
| #include "src/ops/primitive_c.h" | |||
| #include "src/ops/populate/populate_register.h" | |||
| #include "nnacl/arithmetic_common.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| OpParameter *PopulateEltwiseParameter(const mindspore::lite::PrimitiveC *primitive) { | |||
| ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter))); | |||
| if (arithmetic_param == nullptr) { | |||
| MS_LOG(ERROR) << "malloc ArithmeticParameter failed."; | |||
| return nullptr; | |||
| } | |||
| memset(arithmetic_param, 0, sizeof(ArithmeticParameter)); | |||
| auto eltwise = reinterpret_cast<mindspore::lite::Eltwise *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||
| switch (eltwise->GetMode()) { | |||
| case schema::EltwiseMode_PROD: | |||
| arithmetic_param->op_parameter_.type_ = schema::PrimitiveType_Mul; | |||
| break; | |||
| case schema::EltwiseMode_SUM: | |||
| arithmetic_param->op_parameter_.type_ = schema::PrimitiveType_Add; | |||
| break; | |||
| case schema::EltwiseMode_MAXIMUM: | |||
| arithmetic_param->op_parameter_.type_ = schema::PrimitiveType_Maximum; | |||
| break; | |||
| default: | |||
| free(arithmetic_param); | |||
| return nullptr; | |||
| } | |||
| return reinterpret_cast<OpParameter *>(arithmetic_param); | |||
| } | |||
| Registry EltwiseParameterRegistry(schema::PrimitiveType_Eltwise, PopulateEltwiseParameter); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -1,48 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/mul.h" | |||
| #include "nnacl/arithmetic_common.h" | |||
| #include "src/ops/primitive_c.h" | |||
| #include "src/ops/populate/populate_register.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| OpParameter *PopulateMulParameter(const mindspore::lite::PrimitiveC *primitive) { | |||
| ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter))); | |||
| if (arithmetic_param == nullptr) { | |||
| MS_LOG(ERROR) << "malloc ArithmeticParameter failed."; | |||
| return nullptr; | |||
| } | |||
| memset(arithmetic_param, 0, sizeof(ArithmeticParameter)); | |||
| arithmetic_param->op_parameter_.type_ = primitive->Type(); | |||
| arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting(); | |||
| arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims(); | |||
| arithmetic_param->activation_type_ = | |||
| reinterpret_cast<mindspore::lite::Mul *>(const_cast<mindspore::lite::PrimitiveC *>(primitive))->GetActivationType(); | |||
| auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0(); | |||
| memcpy(arithmetic_param->in_shape0_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); | |||
| tmp_shape = ((lite::Arithmetic *)primitive)->InShape1(); | |||
| memcpy(arithmetic_param->in_shape1_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); | |||
| tmp_shape = ((lite::Arithmetic *)primitive)->OutputShape(); | |||
| memcpy(arithmetic_param->out_shape_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); | |||
| return reinterpret_cast<OpParameter *>(arithmetic_param); | |||
| } | |||
| Registry MulParameterRegistry(schema::PrimitiveType_Mul, PopulateMulParameter); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -1,47 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/sub.h" | |||
| #include "src/ops/primitive_c.h" | |||
| #include "src/ops/populate/populate_register.h" | |||
| #include "nnacl/arithmetic_common.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| OpParameter *PopulateSubParameter(const mindspore::lite::PrimitiveC *primitive) { | |||
| ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter))); | |||
| if (arithmetic_param == nullptr) { | |||
| MS_LOG(ERROR) << "malloc ArithmeticParameter failed."; | |||
| return nullptr; | |||
| } | |||
| memset(arithmetic_param, 0, sizeof(ArithmeticParameter)); | |||
| arithmetic_param->op_parameter_.type_ = primitive->Type(); | |||
| arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting(); | |||
| arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims(); | |||
| arithmetic_param->activation_type_ = | |||
| reinterpret_cast<mindspore::lite::Sub *>(const_cast<mindspore::lite::PrimitiveC *>(primitive))->GetActivationType(); | |||
| auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0(); | |||
| memcpy(arithmetic_param->in_shape0_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); | |||
| tmp_shape = ((lite::Arithmetic *)primitive)->InShape1(); | |||
| memcpy(arithmetic_param->in_shape1_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); | |||
| tmp_shape = ((lite::Arithmetic *)primitive)->OutputShape(); | |||
| memcpy(arithmetic_param->out_shape_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); | |||
| return reinterpret_cast<OpParameter *>(arithmetic_param); | |||
| } | |||
| Registry SubParameterRegistry(schema::PrimitiveType_Sub, PopulateSubParameter); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -15,6 +15,8 @@ | |||
| */ | |||
| #include "src/runtime/kernel/arm/int8/arithmetic_int8.h" | |||
| #include "src/runtime/kernel/arm/int8/add_int8.h" | |||
| #include "src/runtime/kernel/arm/int8/mul_int8.h" | |||
| #include "nnacl/arithmetic_common.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_registry.h" | |||
| @@ -27,11 +29,14 @@ using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| using mindspore::lite::RET_PARAM_INVALID; | |||
| using mindspore::schema::PrimitiveType_Add; | |||
| using mindspore::schema::PrimitiveType_Eltwise; | |||
| using mindspore::schema::PrimitiveType_Equal; | |||
| using mindspore::schema::PrimitiveType_Greater; | |||
| using mindspore::schema::PrimitiveType_GreaterEqual; | |||
| using mindspore::schema::PrimitiveType_Less; | |||
| using mindspore::schema::PrimitiveType_LessEqual; | |||
| using mindspore::schema::PrimitiveType_Mul; | |||
| using mindspore::schema::PrimitiveType_NotEqual; | |||
| namespace mindspore::kernel { | |||
| @@ -159,11 +164,15 @@ kernel::LiteKernel *CpuArithmeticInt8KernelCreator(const std::vector<lite::Tenso | |||
| const std::vector<lite::Tensor *> &outputs, OpParameter *parameter, | |||
| const lite::InnerContext *ctx, const kernel::KernelKey &desc, | |||
| const mindspore::lite::PrimitiveC *primitive) { | |||
| if (parameter == nullptr) { | |||
| MS_LOG(ERROR) << "Input parameter is null!"; | |||
| return nullptr; | |||
| kernel::LiteKernel *kernel = nullptr; | |||
| if (desc.type == PrimitiveType_Eltwise && static_cast<schema::PrimitiveType>(parameter->type_) == PrimitiveType_Add) { | |||
| kernel = new (std::nothrow) QuantizedAddCPUKernel(parameter, inputs, outputs, ctx, primitive); | |||
| } else if (desc.type == PrimitiveType_Eltwise && | |||
| static_cast<schema::PrimitiveType>(parameter->type_) == PrimitiveType_Mul) { | |||
| kernel = new (std::nothrow) MulInt8CPUKernel(parameter, inputs, outputs, ctx, primitive); | |||
| } else { | |||
| kernel = new (std::nothrow) ArithmeticInt8CPUKernel(parameter, inputs, outputs, ctx, primitive); | |||
| } | |||
| auto kernel = new (std::nothrow) ArithmeticInt8CPUKernel(parameter, inputs, outputs, ctx, primitive); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "Create ArithmeticInt8CPUKernel failed, name: " << parameter->name_; | |||
| free(parameter); | |||
| @@ -185,5 +194,5 @@ REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Less, CpuArithmeticInt8KernelCre | |||
| REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_LessEqual, CpuArithmeticInt8KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Greater, CpuArithmeticInt8KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_GreaterEqual, CpuArithmeticInt8KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Eltwise, CpuArithmeticInt8KernelCreator) | |||
| } // namespace mindspore::kernel | |||