Merge pull request !7570 from yeyunpeng2020/primitivetags/v1.1.0
| @@ -19,6 +19,7 @@ | |||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "src/common/graph_util.h" | #include "src/common/graph_util.h" | ||||
| #include "include/version.h" | #include "include/version.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| @@ -31,7 +32,12 @@ bool ConvertNodes(const schema::MetaGraph *meta_graph, Model *model) { | |||||
| } | } | ||||
| auto c_node = meta_graph->nodes()->GetAs<schema::CNode>(i); | auto c_node = meta_graph->nodes()->GetAs<schema::CNode>(i); | ||||
| auto src_prim = c_node->primitive(); | auto src_prim = c_node->primitive(); | ||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| node->primitive_ = PrimitiveC::Create(const_cast<schema::Primitive *>(src_prim)); | node->primitive_ = PrimitiveC::Create(const_cast<schema::Primitive *>(src_prim)); | ||||
| #else | |||||
| auto primitive = const_cast<schema::Primitive *>(src_prim); | |||||
| node->primitive_ = OpsRegistry::GetInstance()->getPrimitiveCreator(primitive->value_type())(primitive); | |||||
| #endif | |||||
| if (node->primitive_ == nullptr) { | if (node->primitive_ == nullptr) { | ||||
| MS_LOG(ERROR) << "unpack primitive == nullptr!"; | MS_LOG(ERROR) << "unpack primitive == nullptr!"; | ||||
| delete node; | delete node; | ||||
| @@ -15,6 +15,7 @@ | |||||
| */ | */ | ||||
| #include "src/ops/abs.h" | #include "src/ops/abs.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -27,6 +28,9 @@ int Abs::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *AbsCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Abs>(primitive); } | |||||
| Registry AbsRegistry(schema::PrimitiveType_Abs, AbsCreator); | |||||
| #endif | #endif | ||||
| Registry AbsParameterRegistry(schema::PrimitiveType_Abs, PopulateArithmeticSelf); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/ops/activation.h" | #include "src/ops/activation.h" | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/fp32/activation.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -80,6 +82,30 @@ int Activation::GetType() const { return this->primitive_->value_as_Activation() | |||||
| float Activation::GetAlpha() const { return this->primitive_->value_as_Activation()->alpha(); } | float Activation::GetAlpha() const { return this->primitive_->value_as_Activation()->alpha(); } | ||||
| float Activation::GetMinVal() const { return this->primitive_->value_as_Activation()->min_val(); } | float Activation::GetMinVal() const { return this->primitive_->value_as_Activation()->min_val(); } | ||||
| float Activation::GetMaxVal() const { return this->primitive_->value_as_Activation()->max_val(); } | float Activation::GetMaxVal() const { return this->primitive_->value_as_Activation()->max_val(); } | ||||
| PrimitiveC *ActivationCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<Activation>(primitive); | |||||
| } | |||||
| Registry ActivationRegistry(schema::PrimitiveType_Activation, ActivationCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateActivationParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| ActivationParameter *act_param = reinterpret_cast<ActivationParameter *>(malloc(sizeof(ActivationParameter))); | |||||
| if (act_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ActivationParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(act_param, 0, sizeof(ActivationParameter)); | |||||
| act_param->op_parameter_.type_ = primitive->Type(); | |||||
| auto activation = | |||||
| reinterpret_cast<mindspore::lite::Activation *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| act_param->type_ = static_cast<int>(activation->GetType()); | |||||
| act_param->alpha_ = activation->GetAlpha(); | |||||
| act_param->min_val_ = activation->GetMinVal(); | |||||
| act_param->max_val_ = activation->GetMaxVal(); | |||||
| return reinterpret_cast<OpParameter *>(act_param); | |||||
| } | |||||
| Registry ActivationParameterRegistry(schema::PrimitiveType_Activation, PopulateActivationParameter); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/ops/activation_grad.h" | #include "src/ops/activation_grad.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -74,6 +76,11 @@ int ActivationGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flat | |||||
| } | } | ||||
| int ActivationGrad::GetType() const { return this->primitive_->value_as_ActivationGrad()->type(); } | int ActivationGrad::GetType() const { return this->primitive_->value_as_ActivationGrad()->type(); } | ||||
| float ActivationGrad::GetAlpha() const { return this->primitive_->value_as_ActivationGrad()->alpha(); } | float ActivationGrad::GetAlpha() const { return this->primitive_->value_as_ActivationGrad()->alpha(); } | ||||
| PrimitiveC *ActivationGradCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<ActivationGrad>(primitive); | |||||
| } | |||||
| Registry ActivationGradRegistry(schema::PrimitiveType_ActivationGrad, ActivationGradCreator); | |||||
| #endif | #endif | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -14,6 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/ops/adam.h" | #include "src/ops/adam.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -62,6 +64,9 @@ int Adam::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::F | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *AdamCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Adam>(primitive); } | |||||
| Registry AdamRegistry(schema::PrimitiveType_Adam, AdamCreator); | |||||
| #endif | #endif | ||||
| int Adam::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) { | int Adam::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) { | ||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/ops/add.h" | #include "src/ops/add.h" | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/arithmetic_common.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -71,6 +73,31 @@ int Add::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl | |||||
| } | } | ||||
| int Add::GetActivationType() const { return this->primitive_->value_as_Add()->activationType(); } | int Add::GetActivationType() const { return this->primitive_->value_as_Add()->activationType(); } | ||||
| PrimitiveC *AddCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Add>(primitive); } | |||||
| Registry AddRegistry(schema::PrimitiveType_Add, AddCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateAddParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter))); | |||||
| if (arithmetic_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ArithmeticParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(arithmetic_param, 0, sizeof(ArithmeticParameter)); | |||||
| arithmetic_param->op_parameter_.type_ = primitive->Type(); | |||||
| arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting(); | |||||
| arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims(); | |||||
| arithmetic_param->activation_type_ = | |||||
| reinterpret_cast<mindspore::lite::Add *>(const_cast<mindspore::lite::PrimitiveC *>(primitive))->GetActivationType(); | |||||
| auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0(); | |||||
| memcpy(arithmetic_param->in_shape0_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); | |||||
| tmp_shape = ((lite::Arithmetic *)primitive)->InShape1(); | |||||
| memcpy(arithmetic_param->in_shape1_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); | |||||
| tmp_shape = ((lite::Arithmetic *)primitive)->OutputShape(); | |||||
| memcpy(arithmetic_param->out_shape_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); | |||||
| return reinterpret_cast<OpParameter *>(arithmetic_param); | |||||
| } | |||||
| Registry AddParameterRegistry(schema::PrimitiveType_Add, PopulateAddParameter); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/ops/addn.h" | #include "src/ops/addn.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -62,8 +64,22 @@ int AddN::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::F | |||||
| } | } | ||||
| int AddN::GetN() const { return this->primitive_->value_as_AddN()->N(); } | int AddN::GetN() const { return this->primitive_->value_as_AddN()->N(); } | ||||
| PrimitiveC *AddNCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<AddN>(primitive); } | |||||
| Registry AddNRegistry(schema::PrimitiveType_AddN, AddNCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateAddNParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| OpParameter *addn_param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | |||||
| if (addn_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc OpParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(addn_param, 0, sizeof(OpParameter)); | |||||
| addn_param->type_ = primitive->Type(); | |||||
| return reinterpret_cast<OpParameter *>(addn_param); | |||||
| } | |||||
| Registry AddNParameterRegistry(schema::PrimitiveType_AddN, PopulateAddNParameter); | |||||
| namespace { | namespace { | ||||
| constexpr int kLeastInputNum = 2; | constexpr int kLeastInputNum = 2; | ||||
| } | } | ||||
| @@ -14,6 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/ops/apply_momentum.h" | #include "src/ops/apply_momentum.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -67,6 +69,11 @@ int ApplyMomentum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatb | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *ApplyMomentumCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<ApplyMomentum>(primitive); | |||||
| } | |||||
| Registry ApplyMomentumRegistry(schema::PrimitiveType_ApplyMomentum, ApplyMomentumCreator); | |||||
| #endif | #endif | ||||
| int ApplyMomentum::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) { | int ApplyMomentum::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) { | ||||
| @@ -16,6 +16,9 @@ | |||||
| #include "src/ops/argmax.h" | #include "src/ops/argmax.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/arg_min_max_parameter.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -52,8 +55,29 @@ int ArgMax::GetTopK() const { return this->primitive_->value_as_ArgMax()->topK() | |||||
| bool ArgMax::GetKeepDims() const { return this->primitive_->value_as_ArgMax()->keepDims(); } | bool ArgMax::GetKeepDims() const { return this->primitive_->value_as_ArgMax()->keepDims(); } | ||||
| int ArgMax::GetAxisType() const { return this->primitive_->value_as_ArgMax()->axisType(); } | int ArgMax::GetAxisType() const { return this->primitive_->value_as_ArgMax()->axisType(); } | ||||
| PrimitiveC *ArgMaxCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<ArgMax>(primitive); } | |||||
| Registry ArgMaxRegistry(schema::PrimitiveType_ArgMax, ArgMaxCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateArgMaxParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| ArgMinMaxParameter *arg_param = reinterpret_cast<ArgMinMaxParameter *>(malloc(sizeof(ArgMinMaxParameter))); | |||||
| if (arg_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ArgMinMaxParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(arg_param, 0, sizeof(ArgMinMaxParameter)); | |||||
| arg_param->op_parameter_.type_ = primitive->Type(); | |||||
| auto param = reinterpret_cast<mindspore::lite::ArgMax *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| arg_param->axis_ = param->GetAxis(); | |||||
| arg_param->topk_ = param->GetTopK(); | |||||
| arg_param->axis_type_ = param->GetAxisType(); | |||||
| arg_param->out_value_ = param->GetOutMaxValue(); | |||||
| arg_param->keep_dims_ = param->GetKeepDims(); | |||||
| return reinterpret_cast<OpParameter *>(arg_param); | |||||
| } | |||||
| Registry ArgMaxParameterRegistry(schema::PrimitiveType_ArgMax, PopulateArgMaxParameter); | |||||
| int ArgMax::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | int ArgMax::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | ||||
| MS_ASSERT(this->primitive_ != nullptr); | MS_ASSERT(this->primitive_ != nullptr); | ||||
| auto input = inputs_.front(); | auto input = inputs_.front(); | ||||
| @@ -16,6 +16,9 @@ | |||||
| #include "src/ops/argmin.h" | #include "src/ops/argmin.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/arg_min_max_parameter.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -52,8 +55,29 @@ int ArgMin::GetTopK() const { return this->primitive_->value_as_ArgMin()->topK() | |||||
| bool ArgMin::GetKeepDims() const { return this->primitive_->value_as_ArgMin()->keepDims(); } | bool ArgMin::GetKeepDims() const { return this->primitive_->value_as_ArgMin()->keepDims(); } | ||||
| int ArgMin::GetAxisType() const { return this->primitive_->value_as_ArgMin()->axisType(); } | int ArgMin::GetAxisType() const { return this->primitive_->value_as_ArgMin()->axisType(); } | ||||
| PrimitiveC *ArgMinCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<ArgMin>(primitive); } | |||||
| Registry ArgMinRegistry(schema::PrimitiveType_ArgMin, ArgMinCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateArgMinParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| ArgMinMaxParameter *arg_param = reinterpret_cast<ArgMinMaxParameter *>(malloc(sizeof(ArgMinMaxParameter))); | |||||
| if (arg_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ArgMinMaxParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(arg_param, 0, sizeof(ArgMinMaxParameter)); | |||||
| arg_param->op_parameter_.type_ = primitive->Type(); | |||||
| auto param = reinterpret_cast<mindspore::lite::ArgMin *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| arg_param->axis_ = param->GetAxis(); | |||||
| arg_param->topk_ = param->GetTopK(); | |||||
| arg_param->axis_type_ = param->GetAxisType(); | |||||
| arg_param->out_value_ = param->GetOutMaxValue(); | |||||
| arg_param->keep_dims_ = param->GetKeepDims(); | |||||
| return reinterpret_cast<OpParameter *>(arg_param); | |||||
| } | |||||
| Registry ArgMinParameterRegistry(schema::PrimitiveType_ArgMin, PopulateArgMinParameter); | |||||
| int ArgMin::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) { | int ArgMin::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) { | ||||
| MS_ASSERT(this->primitive_ != nullptr); | MS_ASSERT(this->primitive_ != nullptr); | ||||
| auto input = inputs_.front(); | auto input = inputs_.front(); | ||||
| @@ -21,6 +21,29 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateArithmetic(const mindspore::lite::PrimitiveC *primitive) { | |||||
| ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter))); | |||||
| if (arithmetic_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ArithmeticParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(arithmetic_param, 0, sizeof(ArithmeticParameter)); | |||||
| arithmetic_param->op_parameter_.type_ = primitive->Type(); | |||||
| arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting(); | |||||
| arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims(); | |||||
| arithmetic_param->activation_type_ = 0; | |||||
| auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0(); | |||||
| memcpy(arithmetic_param->in_shape0_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); | |||||
| tmp_shape = ((lite::Arithmetic *)primitive)->InShape1(); | |||||
| memcpy(arithmetic_param->in_shape1_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); | |||||
| tmp_shape = ((lite::Arithmetic *)primitive)->OutputShape(); | |||||
| memcpy(arithmetic_param->out_shape_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); | |||||
| return reinterpret_cast<OpParameter *>(arithmetic_param); | |||||
| } | |||||
| int Arithmetic::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) { | int Arithmetic::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) { | ||||
| MS_ASSERT(this->primitive_ != nullptr); | MS_ASSERT(this->primitive_ != nullptr); | ||||
| if (inputs_.size() != kDoubleNum) { | if (inputs_.size() != kDoubleNum) { | ||||
| @@ -21,6 +21,7 @@ | |||||
| #include <set> | #include <set> | ||||
| #include <cmath> | #include <cmath> | ||||
| #include "src/ops/primitive_c.h" | #include "src/ops/primitive_c.h" | ||||
| #include "nnacl/arithmetic_common.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -51,6 +52,8 @@ class Arithmetic : public PrimitiveC { | |||||
| std::vector<int> in_shape1_; | std::vector<int> in_shape1_; | ||||
| std::vector<int> out_shape_; | std::vector<int> out_shape_; | ||||
| }; | }; | ||||
| OpParameter *PopulateArithmetic(const mindspore::lite::PrimitiveC *primitive); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -21,6 +21,7 @@ | |||||
| #include <set> | #include <set> | ||||
| #include <cmath> | #include <cmath> | ||||
| #include "src/ops/primitive_c.h" | #include "src/ops/primitive_c.h" | ||||
| #include "nnacl/arithmetic_self_parameter.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -17,9 +17,21 @@ | |||||
| #include "src/ops/arithmetic_self.h" | #include "src/ops/arithmetic_self.h" | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "src/common/log_adapter.h" | #include "src/common/log_adapter.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateArithmeticSelf(const mindspore::lite::PrimitiveC *primitive) { | |||||
| ArithmeticSelfParameter *arithmetic_self_param = | |||||
| reinterpret_cast<ArithmeticSelfParameter *>(malloc(sizeof(ArithmeticSelfParameter))); | |||||
| if (arithmetic_self_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ArithmeticSelfParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(arithmetic_self_param, 0, sizeof(ArithmeticSelfParameter)); | |||||
| arithmetic_self_param->op_parameter_.type_ = primitive->Type(); | |||||
| return reinterpret_cast<OpParameter *>(arithmetic_self_param); | |||||
| } | |||||
| int ArithmeticSelf::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | int ArithmeticSelf::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | ||||
| MS_ASSERT(this->primitive_ != nullptr); | MS_ASSERT(this->primitive_ != nullptr); | ||||
| @@ -19,6 +19,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "src/ops/primitive_c.h" | #include "src/ops/primitive_c.h" | ||||
| #include "nnacl/arithmetic_self_parameter.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -37,6 +38,7 @@ class ArithmeticSelf : public PrimitiveC { | |||||
| #endif | #endif | ||||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | ||||
| }; | }; | ||||
| OpParameter *PopulateArithmeticSelf(const mindspore::lite::PrimitiveC *primitive); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -17,6 +17,8 @@ | |||||
| #include "src/ops/assign.h" | #include "src/ops/assign.h" | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -56,6 +58,9 @@ int Assign::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers: | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *AssignCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Assign>(primitive); } | |||||
| Registry AssignRegistry(schema::PrimitiveType_Assign, AssignCreator); | |||||
| #endif | #endif | ||||
| int Assign::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) { | int Assign::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) { | ||||
| @@ -16,6 +16,9 @@ | |||||
| #include "src/ops/batch_norm.h" | #include "src/ops/batch_norm.h" | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/batchnorm_parameter.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -60,6 +63,28 @@ int BatchNorm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffe | |||||
| } | } | ||||
| float BatchNorm::GetEpsilon() const { return this->primitive_->value_as_BatchNorm()->epsilon(); } | float BatchNorm::GetEpsilon() const { return this->primitive_->value_as_BatchNorm()->epsilon(); } | ||||
| PrimitiveC *BatchNormCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<BatchNorm>(primitive); | |||||
| } | |||||
| Registry BatchNormRegistry(schema::PrimitiveType_BatchNorm, BatchNormCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateBatchNorm(const mindspore::lite::PrimitiveC *primitive) { | |||||
| const auto param = | |||||
| reinterpret_cast<mindspore::lite::BatchNorm *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| BatchNormParameter *batch_norm_param = reinterpret_cast<BatchNormParameter *>(malloc(sizeof(BatchNormParameter))); | |||||
| if (batch_norm_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc BatchNormParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(batch_norm_param, 0, sizeof(BatchNormParameter)); | |||||
| batch_norm_param->op_parameter_.type_ = primitive->Type(); | |||||
| batch_norm_param->epsilon_ = param->GetEpsilon(); | |||||
| batch_norm_param->fused_ = false; | |||||
| return reinterpret_cast<OpParameter *>(batch_norm_param); | |||||
| } | |||||
| Registry BatchNormParameterRegistry(schema::PrimitiveType_BatchNorm, PopulateBatchNorm); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -20,6 +20,9 @@ | |||||
| #include "src/common/log_adapter.h" | #include "src/common/log_adapter.h" | ||||
| #include "src/tensor.h" | #include "src/tensor.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/batch_to_space.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -66,7 +69,49 @@ std::vector<int> BatchToSpace::GetCrops() const { | |||||
| return std::vector<int>(fb_vector->begin(), fb_vector->end()); | return std::vector<int>(fb_vector->begin(), fb_vector->end()); | ||||
| } | } | ||||
| PrimitiveC *BatchToSpaceCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<BatchToSpace>(primitive); | |||||
| } | |||||
| Registry BatchToSpaceRegistry(schema::PrimitiveType_BatchToSpace, BatchToSpaceCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateBatchToSpaceParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| BatchToSpaceParameter *batch_space_param = | |||||
| reinterpret_cast<BatchToSpaceParameter *>(malloc(sizeof(BatchToSpaceParameter))); | |||||
| if (batch_space_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc BatchToSpaceParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(batch_space_param, 0, sizeof(BatchToSpaceParameter)); | |||||
| batch_space_param->op_parameter_.type_ = primitive->Type(); | |||||
| auto param = reinterpret_cast<mindspore::lite::BatchToSpace *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| auto block_shape = param->GetBlockShape(); | |||||
| if (block_shape.size() != BATCH_TO_SPACE_BLOCK_SHAPE_SIZE) { | |||||
| MS_LOG(ERROR) << "batch_to_space blockShape size should be " << BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; | |||||
| free(batch_space_param); | |||||
| return nullptr; | |||||
| } | |||||
| auto crops = param->GetCrops(); | |||||
| if (crops.size() != BATCH_TO_SPACE_CROPS_SIZE) { | |||||
| MS_LOG(ERROR) << "batch_to_space crops size should be " << BATCH_TO_SPACE_CROPS_SIZE; | |||||
| free(batch_space_param); | |||||
| return nullptr; | |||||
| } | |||||
| for (int i = 0; i < BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; ++i) { | |||||
| batch_space_param->block_shape_[i] = block_shape[i]; | |||||
| } | |||||
| for (int i = 0; i < BATCH_TO_SPACE_CROPS_SIZE; ++i) { | |||||
| batch_space_param->crops_[i] = crops[i]; | |||||
| } | |||||
| return reinterpret_cast<OpParameter *>(batch_space_param); | |||||
| } | |||||
| Registry BatchToSpaceParameterRegistry(schema::PrimitiveType_BatchToSpace, PopulateBatchToSpaceParameter); | |||||
| Registry BatchToSpaceNDParameterRegistry(schema::PrimitiveType_BatchToSpaceND, PopulateBatchToSpaceParameter); | |||||
| namespace { | namespace { | ||||
| constexpr int kBatchToSpaceOutputNum = 1; | constexpr int kBatchToSpaceOutputNum = 1; | ||||
| constexpr int kBatchToSpaceInputNum = 1; | constexpr int kBatchToSpaceInputNum = 1; | ||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/ops/bias_add.h" | #include "src/ops/bias_add.h" | ||||
| #include <memory> | #include <memory> | ||||
| #include "nnacl/arithmetic_common.h" | |||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -78,6 +80,22 @@ std::vector<int> BiasAdd::GetAxis() const { | |||||
| return std::vector<int>(fb_vector->begin(), fb_vector->end()); | return std::vector<int>(fb_vector->begin(), fb_vector->end()); | ||||
| } | } | ||||
| PrimitiveC *BiasAddCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<BiasAdd>(primitive); } | |||||
| Registry BiasAddRegistry(schema::PrimitiveType_BiasAdd, BiasAddCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateBiasAddParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter))); | |||||
| if (arithmetic_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ArithmeticParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(arithmetic_param, 0, sizeof(ArithmeticParameter)); | |||||
| arithmetic_param->op_parameter_.type_ = primitive->Type(); | |||||
| return reinterpret_cast<OpParameter *>(arithmetic_param); | |||||
| } | |||||
| Registry BiasAddParameterRegistry(schema::PrimitiveType_BiasAdd, PopulateBiasAddParameter); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/ops/bias_grad.h" | #include "src/ops/bias_grad.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -74,6 +76,11 @@ std::vector<int> BiasGrad::GetAxis() const { | |||||
| auto fb_vector = this->primitive_->value_as_BiasGrad()->axis(); | auto fb_vector = this->primitive_->value_as_BiasGrad()->axis(); | ||||
| return std::vector<int>(fb_vector->begin(), fb_vector->end()); | return std::vector<int>(fb_vector->begin(), fb_vector->end()); | ||||
| } | } | ||||
| PrimitiveC *BiasGradCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<BiasGrad>(primitive); | |||||
| } | |||||
| Registry BiasGradRegistry(schema::PrimitiveType_BiasGrad, BiasGradCreator); | |||||
| #endif | #endif | ||||
| int BiasGrad::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) { | int BiasGrad::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) { | ||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/ops/bn_grad.h" | #include "src/ops/bn_grad.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -16,6 +16,9 @@ | |||||
| #include "src/ops/broadcast_to.h" | #include "src/ops/broadcast_to.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/fp32/broadcast_to.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -50,7 +53,32 @@ std::vector<int> BroadcastTo::GetDstShape() const { | |||||
| return std::vector<int>(fb_vector->begin(), fb_vector->end()); | return std::vector<int>(fb_vector->begin(), fb_vector->end()); | ||||
| } | } | ||||
| PrimitiveC *BroadcastToCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<BroadcastTo>(primitive); | |||||
| } | |||||
| Registry BroadcastToRegistry(schema::PrimitiveType_BroadcastTo, BroadcastToCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateBroadcastToParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| BroadcastToParameter *broadcast_param = | |||||
| reinterpret_cast<BroadcastToParameter *>(malloc(sizeof(BroadcastToParameter))); | |||||
| if (broadcast_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc BroadcastToParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(broadcast_param, 0, sizeof(BroadcastToParameter)); | |||||
| auto param = reinterpret_cast<mindspore::lite::BroadcastTo *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| broadcast_param->op_parameter_.type_ = primitive->Type(); | |||||
| auto dst_shape = param->GetDstShape(); | |||||
| broadcast_param->shape_size_ = dst_shape.size(); | |||||
| for (size_t i = 0; i < broadcast_param->shape_size_; ++i) { | |||||
| broadcast_param->shape_[i] = dst_shape[i]; | |||||
| } | |||||
| return reinterpret_cast<OpParameter *>(broadcast_param); | |||||
| } | |||||
| Registry BroadcastToParameterRegistry(schema::PrimitiveType_BroadcastTo, PopulateBroadcastToParameter); | |||||
| namespace { | namespace { | ||||
| constexpr int kBroadcastToInputNum = 1; | constexpr int kBroadcastToInputNum = 1; | ||||
| constexpr int kBroadcastToOutputNum = 1; | constexpr int kBroadcastToOutputNum = 1; | ||||
| @@ -16,6 +16,9 @@ | |||||
| #include "src/ops/cast.h" | #include "src/ops/cast.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/fp32/cast.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -75,8 +78,26 @@ int Cast::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::F | |||||
| int Cast::GetSrcT() const { return this->primitive_->value_as_Cast()->srcT(); } | int Cast::GetSrcT() const { return this->primitive_->value_as_Cast()->srcT(); } | ||||
| int Cast::GetDstT() const { return this->primitive_->value_as_Cast()->dstT(); } | int Cast::GetDstT() const { return this->primitive_->value_as_Cast()->dstT(); } | ||||
| PrimitiveC *CastCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Cast>(primitive); } | |||||
| Registry CastRegistry(schema::PrimitiveType_Cast, CastCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateCastParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| CastParameter *cast_param = reinterpret_cast<CastParameter *>(malloc(sizeof(CastParameter))); | |||||
| if (cast_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc CastParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(cast_param, 0, sizeof(CastParameter)); | |||||
| cast_param->op_parameter_.type_ = primitive->Type(); | |||||
| auto param = reinterpret_cast<mindspore::lite::Cast *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| cast_param->src_type_ = param->GetSrcT(); | |||||
| cast_param->dst_type_ = param->GetDstT(); | |||||
| return reinterpret_cast<OpParameter *>(cast_param); | |||||
| } | |||||
| Registry CastParameterRegistry(schema::PrimitiveType_Cast, PopulateCastParameter); | |||||
| int Cast::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | int Cast::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | ||||
| MS_ASSERT(this->primitive_ != nullptr); | MS_ASSERT(this->primitive_ != nullptr); | ||||
| auto input = inputs_.front(); | auto input = inputs_.front(); | ||||
| @@ -0,0 +1,27 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/ops/ceil.h" | |||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| Registry CeilParameterRegistry(schema::PrimitiveType_Ceil, PopulateArithmeticSelf); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -21,6 +21,7 @@ | |||||
| #include <set> | #include <set> | ||||
| #include <cmath> | #include <cmath> | ||||
| #include "src/ops/arithmetic_self.h" | #include "src/ops/arithmetic_self.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -43,6 +44,7 @@ class Ceil : public ArithmeticSelf { | |||||
| } | } | ||||
| #endif | #endif | ||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,9 @@ | |||||
| #include "src/ops/clip.h" | #include "src/ops/clip.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/clip.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -42,6 +45,24 @@ int Clip::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::F | |||||
| float Clip::GetMax() const { return this->primitive_->value_as_Clip()->max(); } | float Clip::GetMax() const { return this->primitive_->value_as_Clip()->max(); } | ||||
| float Clip::GetMin() const { return this->primitive_->value_as_Clip()->min(); } | float Clip::GetMin() const { return this->primitive_->value_as_Clip()->min(); } | ||||
| PrimitiveC *ClipCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Clip>(primitive); } | |||||
| Registry ClipRegistry(schema::PrimitiveType_Clip, ClipCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateClipParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| ClipParameter *act_param = reinterpret_cast<ClipParameter *>(malloc(sizeof(ClipParameter))); | |||||
| if (act_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ClipParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(act_param, 0, sizeof(ClipParameter)); | |||||
| act_param->op_parameter_.type_ = primitive->Type(); | |||||
| auto activation = reinterpret_cast<mindspore::lite::Clip *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| act_param->min_val_ = activation->GetMin(); | |||||
| act_param->max_val_ = activation->GetMax(); | |||||
| return reinterpret_cast<OpParameter *>(act_param); | |||||
| } | |||||
| Registry ClipParameterRegistry(schema::PrimitiveType_Clip, PopulateClipParameter); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,6 +19,8 @@ | |||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "src/common/log_adapter.h" | #include "src/common/log_adapter.h" | ||||
| #include "src/tensor.h" | #include "src/tensor.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/concat_parameter.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -76,8 +78,26 @@ int Concat::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers: | |||||
| int Concat::GetAxis() const { return this->primitive_->value_as_Concat()->axis(); } | int Concat::GetAxis() const { return this->primitive_->value_as_Concat()->axis(); } | ||||
| int Concat::GetN() const { return this->primitive_->value_as_Concat()->n(); } | int Concat::GetN() const { return this->primitive_->value_as_Concat()->n(); } | ||||
| PrimitiveC *ConcatCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Concat>(primitive); } | |||||
| Registry ConcatRegistry(schema::PrimitiveType_Concat, ConcatCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateConcatParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| ConcatParameter *concat_param = reinterpret_cast<ConcatParameter *>(malloc(sizeof(ConcatParameter))); | |||||
| if (concat_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ConcatParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(concat_param, 0, sizeof(ConcatParameter)); | |||||
| concat_param->op_parameter_.type_ = primitive->Type(); | |||||
| auto param = reinterpret_cast<mindspore::lite::Concat *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| concat_param->axis_ = param->GetAxis(); | |||||
| return reinterpret_cast<OpParameter *>(concat_param); | |||||
| } | |||||
| Registry ConcatParameterRegistry(schema::PrimitiveType_Concat, PopulateConcatParameter); | |||||
| namespace { | namespace { | ||||
| constexpr int kConcatOutputNum = 1; | constexpr int kConcatOutputNum = 1; | ||||
| } | } | ||||
| @@ -18,6 +18,8 @@ | |||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "src/common/log_adapter.h" | #include "src/common/log_adapter.h" | ||||
| #include "src/tensor.h" | #include "src/tensor.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/fp32/constant_of_shape.h" | |||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| namespace { | namespace { | ||||
| @@ -45,8 +47,29 @@ int ConstantOfShape::UnPackToFlatBuilder(const schema::Primitive *primitive, fla | |||||
| } | } | ||||
| float ConstantOfShape::GetValue() const { return this->primitive_->value_as_ConstantOfShape()->value(); } | float ConstantOfShape::GetValue() const { return this->primitive_->value_as_ConstantOfShape()->value(); } | ||||
| PrimitiveC *ConstantOfShapeCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<ConstantOfShape>(primitive); | |||||
| } | |||||
| Registry ConstantOfShapeRegistry(schema::PrimitiveType_ConstantOfShape, ConstantOfShapeCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateConstantOfShapeParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| auto attr = | |||||
| reinterpret_cast<mindspore::lite::ConstantOfShape *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| ConstantOfShapeParameter *param = | |||||
| reinterpret_cast<ConstantOfShapeParameter *>(malloc(sizeof(ConstantOfShapeParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ConstantOfShapeParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(ConstantOfShapeParameter)); | |||||
| param->op_parameter_.type_ = primitive->Type(); | |||||
| param->value_ = attr->GetValue(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | |||||
| Registry ConstantOfShapeParameterRegistry(schema::PrimitiveType_ConstantOfShape, PopulateConstantOfShapeParameter); | |||||
| int ConstantOfShape::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | int ConstantOfShape::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | ||||
| if (inputs_.size() != kShapeInputNum) { | if (inputs_.size() != kShapeInputNum) { | ||||
| MS_LOG(ERROR) << "inputs to ConstantOfShape operator should be 1, but " << inputs_.size() << " is given."; | MS_LOG(ERROR) << "inputs to ConstantOfShape operator should be 1, but " << inputs_.size() << " is given."; | ||||
| @@ -24,9 +24,10 @@ | |||||
| #include "src/common/log_adapter.h" | #include "src/common/log_adapter.h" | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| #include <float.h> | #include <float.h> | ||||
| #include "tools/converter/quantizer/quantize_util.h" | #include "tools/converter/quantizer/quantize_util.h" | ||||
| #endif | #endif | ||||
| #include "nnacl/conv_parameter.h" | |||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -320,7 +321,51 @@ int Conv2D::GetDilateH() const { return this->primitive_->value_as_Conv2D()->dil | |||||
| bool Conv2D::GetHasBias() const { return this->primitive_->value_as_Conv2D()->hasBias(); } | bool Conv2D::GetHasBias() const { return this->primitive_->value_as_Conv2D()->hasBias(); } | ||||
| int Conv2D::GetActivationType() const { return this->primitive_->value_as_Conv2D()->activationType(); } | int Conv2D::GetActivationType() const { return this->primitive_->value_as_Conv2D()->activationType(); } | ||||
| PrimitiveC *Conv2DCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Conv2D>(primitive); } | |||||
| Registry Conv2DRegistry(schema::PrimitiveType_Conv2D, Conv2DCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateConvParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| ConvParameter *conv_param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter))); | |||||
| if (conv_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ConvParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(conv_param, 0, sizeof(ConvParameter)); | |||||
| conv_param->op_parameter_.type_ = primitive->Type(); | |||||
| auto conv_primitive = | |||||
| reinterpret_cast<mindspore::lite::Conv2D *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| conv_param->kernel_h_ = conv_primitive->GetKernelH(); | |||||
| conv_param->kernel_w_ = conv_primitive->GetKernelW(); | |||||
| conv_param->group_ = conv_primitive->GetGroup(); | |||||
| conv_param->stride_h_ = conv_primitive->GetStrideH(); | |||||
| conv_param->stride_w_ = conv_primitive->GetStrideW(); | |||||
| auto conv2d_lite_primitive = (lite::Conv2D *)primitive; | |||||
| conv_param->pad_u_ = conv2d_lite_primitive->PadUp(); | |||||
| conv_param->pad_d_ = conv2d_lite_primitive->PadDown(); | |||||
| conv_param->pad_l_ = conv2d_lite_primitive->PadLeft(); | |||||
| conv_param->pad_r_ = conv2d_lite_primitive->PadRight(); | |||||
| conv_param->dilation_h_ = conv_primitive->GetDilateH(); | |||||
| conv_param->dilation_w_ = conv_primitive->GetDilateW(); | |||||
| conv_param->input_channel_ = conv_primitive->GetChannelIn(); | |||||
| conv_param->output_channel_ = conv_primitive->GetChannelOut(); | |||||
| conv_param->group_ = conv_primitive->GetGroup(); | |||||
| auto act_type = conv_primitive->GetActivationType(); | |||||
| switch (act_type) { | |||||
| case schema::ActivationType_RELU: | |||||
| conv_param->act_type_ = ActType_Relu; | |||||
| break; | |||||
| case schema::ActivationType_RELU6: | |||||
| conv_param->act_type_ = ActType_Relu6; | |||||
| break; | |||||
| default: | |||||
| conv_param->act_type_ = ActType_No; | |||||
| break; | |||||
| } | |||||
| return reinterpret_cast<OpParameter *>(conv_param); | |||||
| } | |||||
| Registry Conv2DParameterRegistry(schema::PrimitiveType_Conv2D, PopulateConvParameter); | |||||
| void Conv2D::ConvInferShape(int input_h, int input_w, int *output_h, int *output_w) { | void Conv2D::ConvInferShape(int input_h, int input_w, int *output_h, int *output_w) { | ||||
| MS_ASSERT(this->primitive_ != nullptr); | MS_ASSERT(this->primitive_ != nullptr); | ||||
| int kernel_w = GetKernelW(); | int kernel_w = GetKernelW(); | ||||
| @@ -15,6 +15,7 @@ | |||||
| */ | */ | ||||
| #include "src/ops/conv2d_grad_filter.h" | #include "src/ops/conv2d_grad_filter.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -176,6 +177,10 @@ int Conv2DGradFilter::GetActivationType() const { | |||||
| return this->primitive_->value_as_Conv2DGradFilter()->activationType(); | return this->primitive_->value_as_Conv2DGradFilter()->activationType(); | ||||
| } | } | ||||
| PrimitiveC *Conv2DGradFilterCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<Conv2DGradFilter>(primitive); | |||||
| } | |||||
| Registry conv2DGradFilterRegistry(schema::PrimitiveType_Conv2DGradFilter, Conv2DGradFilterCreator); | |||||
| #endif | #endif | ||||
| int Conv2DGradFilter::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) { | int Conv2DGradFilter::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) { | ||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/ops/conv2d_grad_input.h" | #include "src/ops/conv2d_grad_input.h" | ||||
| #include "src/ops/group_conv2d_grad_input.h" | #include "src/ops/group_conv2d_grad_input.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -178,6 +180,10 @@ int Conv2DGradInput::GetActivationType() const { | |||||
| return this->primitive_->value_as_Conv2DGradInput()->activationType(); | return this->primitive_->value_as_Conv2DGradInput()->activationType(); | ||||
| } | } | ||||
| PrimitiveC *Conv2DGradInputCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<Conv2DGradInput>(primitive); | |||||
| } | |||||
| Registry Conv2DGradInputRegistry(schema::PrimitiveType_Conv2DGradInput, Conv2DGradInputCreator); | |||||
| #endif | #endif | ||||
| int Conv2DGradInput::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) { | int Conv2DGradInput::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) { | ||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/ops/cos.h" | #include "src/ops/cos.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifndef PRIMITIVE_WRITEABLE | #ifndef PRIMITIVE_WRITEABLE | ||||
| @@ -27,6 +29,10 @@ int Cos::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *CosCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Cos>(primitive); } | |||||
| Registry CosRegistry(schema::PrimitiveType_Cos, CosCreator); | |||||
| #endif | #endif | ||||
| Registry CosParameterRegistry(schema::PrimitiveType_Cos, PopulateArithmeticSelf); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,9 @@ | |||||
| #include "src/ops/crop.h" | #include "src/ops/crop.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/crop_parameter.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -51,7 +54,33 @@ std::vector<int64_t> Crop::GetOffsets() const { | |||||
| return std::vector<int64_t>(fb_vector->begin(), fb_vector->end()); | return std::vector<int64_t>(fb_vector->begin(), fb_vector->end()); | ||||
| } | } | ||||
| PrimitiveC *CropCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Crop>(primitive); } | |||||
| Registry CropRegistry(schema::PrimitiveType_Crop, CropCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateCropParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| auto param = reinterpret_cast<mindspore::lite::Crop *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| auto param_offset = param->GetOffsets(); | |||||
| if (param_offset.size() > CROP_OFFSET_MAX_SIZE) { | |||||
| MS_LOG(ERROR) << "crop_param offset size(" << param_offset.size() << ") should <= " << CROP_OFFSET_MAX_SIZE; | |||||
| return nullptr; | |||||
| } | |||||
| CropParameter *crop_param = reinterpret_cast<CropParameter *>(malloc(sizeof(CropParameter))); | |||||
| if (crop_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc CropParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(crop_param, 0, sizeof(CropParameter)); | |||||
| crop_param->op_parameter_.type_ = primitive->Type(); | |||||
| crop_param->axis_ = param->GetAxis(); | |||||
| crop_param->offset_size_ = param_offset.size(); | |||||
| for (size_t i = 0; i < param_offset.size(); ++i) { | |||||
| crop_param->offset_[i] = param_offset[i]; | |||||
| } | |||||
| return reinterpret_cast<OpParameter *>(crop_param); | |||||
| } | |||||
| Registry CropParameterRegistry(schema::PrimitiveType_Crop, PopulateCropParameter); | |||||
| namespace { | namespace { | ||||
| constexpr int kCropOutputNum = 1; | constexpr int kCropOutputNum = 1; | ||||
| constexpr int kCropInputNum = 2; | constexpr int kCropInputNum = 2; | ||||
| @@ -17,6 +17,8 @@ | |||||
| #include "src/common/string_util.h" | #include "src/common/string_util.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -31,8 +33,25 @@ int CustomExtractFeatures::UnPackToFlatBuilder(const schema::Primitive *primitiv | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *CustomExtractFeaturesCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<CustomExtractFeatures>(primitive); | |||||
| } | |||||
| Registry CustomExtractFeaturesRegistry(schema::PrimitiveType_CustomExtractFeatures, CustomExtractFeaturesCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateExtractFeaturesParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| OpParameter *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "new OpParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(OpParameter)); | |||||
| param->type_ = primitive->Type(); | |||||
| return param; | |||||
| } | |||||
| Registry CustomExtractFeaturesParameterRegistry(schema::PrimitiveType_CustomExtractFeatures, | |||||
| PopulateExtractFeaturesParameter); | |||||
| int CustomExtractFeatures::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | int CustomExtractFeatures::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | ||||
| auto input = inputs_.at(0); | auto input = inputs_.at(0); | ||||
| auto output0 = outputs_.at(0); | auto output0 = outputs_.at(0); | ||||
| @@ -17,6 +17,8 @@ | |||||
| #include "src/common/string_util.h" | #include "src/common/string_util.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -30,7 +32,25 @@ int CustomNormalize::UnPackToFlatBuilder(const schema::Primitive *primitive, fla | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *CustomNormalizeCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<CustomNormalize>(primitive); | |||||
| } | |||||
| Registry CustomNormalizeRegistry(schema::PrimitiveType_CustomNormalize, CustomNormalizeCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateCustomNormalizeParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| OpParameter *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "new OpParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(OpParameter)); | |||||
| param->type_ = primitive->Type(); | |||||
| return param; | |||||
| } | |||||
| Registry CustomNormalizeParameterRegistry(schema::PrimitiveType_CustomNormalize, PopulateCustomNormalizeParameter); | |||||
| int CustomNormalize::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | int CustomNormalize::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | ||||
| auto input = inputs_.at(0); | auto input = inputs_.at(0); | ||||
| auto output = outputs_.at(0); | auto output = outputs_.at(0); | ||||
| @@ -15,6 +15,9 @@ | |||||
| */ | */ | ||||
| #include "src/ops/custom_predict.h" | #include "src/ops/custom_predict.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/predict_parameter.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -45,7 +48,27 @@ int CustomPredict::UnPackToFlatBuilder(const schema::Primitive *primitive, flatb | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *CustomPredictCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<CustomPredict>(primitive); | |||||
| } | |||||
| Registry CustomPredictRegistry(schema::PrimitiveType_CustomPredict, CustomPredictCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateCustomPredictParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| PredictParameter *param = reinterpret_cast<PredictParameter *>(malloc(sizeof(PredictParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc param failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(PredictParameter)); | |||||
| param->op_parameter_.type_ = primitive->Type(); | |||||
| auto prim = reinterpret_cast<mindspore::lite::CustomPredict *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| param->output_num = prim->GetOutputNum(); | |||||
| param->weight_threshold = prim->GetWeightThreshold(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | |||||
| Registry CustomPredictParameterRegistry(schema::PrimitiveType_CustomPredict, PopulateCustomPredictParameter); | |||||
| int CustomPredict::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | int CustomPredict::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | ||||
| auto input = inputs_.at(0); | auto input = inputs_.at(0); | ||||
| auto output0 = outputs_.at(0); | auto output0 = outputs_.at(0); | ||||
| @@ -25,6 +25,9 @@ | |||||
| #include "tools/converter/quantizer/quantize_util.h" | #include "tools/converter/quantizer/quantize_util.h" | ||||
| #endif | #endif | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/conv_parameter.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -295,7 +298,51 @@ int DeConv2D::GetDilateH() const { return this->primitive_->value_as_DeConv2D()- | |||||
| bool DeConv2D::GetHasBias() const { return this->primitive_->value_as_DeConv2D()->hasBias(); } | bool DeConv2D::GetHasBias() const { return this->primitive_->value_as_DeConv2D()->hasBias(); } | ||||
| int DeConv2D::GetActivationType() const { return this->primitive_->value_as_DeConv2D()->activationType(); } | int DeConv2D::GetActivationType() const { return this->primitive_->value_as_DeConv2D()->activationType(); } | ||||
| PrimitiveC *DeConv2DCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<DeConv2D>(primitive); | |||||
| } | |||||
| Registry DeConv2DRegistry(schema::PrimitiveType_DeConv2D, DeConv2DCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateDeconvParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| ConvParameter *conv_param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter))); | |||||
| if (conv_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ConvParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(conv_param, 0, sizeof(ConvParameter)); | |||||
| conv_param->op_parameter_.type_ = primitive->Type(); | |||||
| auto conv_primitive = | |||||
| reinterpret_cast<mindspore::lite::DeConv2D *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| conv_param->kernel_h_ = conv_primitive->GetKernelH(); | |||||
| conv_param->kernel_w_ = conv_primitive->GetKernelW(); | |||||
| conv_param->stride_h_ = conv_primitive->GetStrideH(); | |||||
| conv_param->stride_w_ = conv_primitive->GetStrideW(); | |||||
| auto deconv_lite_primitive = (lite::DeConv2D *)primitive; | |||||
| conv_param->pad_u_ = deconv_lite_primitive->PadUp(); | |||||
| conv_param->pad_d_ = deconv_lite_primitive->PadDown(); | |||||
| conv_param->pad_l_ = deconv_lite_primitive->PadLeft(); | |||||
| conv_param->pad_r_ = deconv_lite_primitive->PadRight(); | |||||
| conv_param->dilation_h_ = conv_primitive->GetDilateH(); | |||||
| conv_param->dilation_w_ = conv_primitive->GetDilateW(); | |||||
| auto act_type = conv_primitive->GetActivationType(); | |||||
| switch (act_type) { | |||||
| case schema::ActivationType_RELU: | |||||
| conv_param->act_type_ = ActType_Relu; | |||||
| break; | |||||
| case schema::ActivationType_RELU6: | |||||
| conv_param->act_type_ = ActType_Relu6; | |||||
| break; | |||||
| default: | |||||
| conv_param->act_type_ = ActType_No; | |||||
| break; | |||||
| } | |||||
| return reinterpret_cast<OpParameter *>(conv_param); | |||||
| } | |||||
| Registry DeConv2DParameterRegistry(schema::PrimitiveType_DeConv2D, PopulateDeconvParameter); | |||||
| int DeConv2D::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) { | int DeConv2D::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) { | ||||
| MS_ASSERT(this->primitive_ != nullptr); | MS_ASSERT(this->primitive_ != nullptr); | ||||
| auto input = inputs_.front(); | auto input = inputs_.front(); | ||||
| @@ -16,6 +16,9 @@ | |||||
| #include "src/ops/dedepthwise_conv2d.h" | #include "src/ops/dedepthwise_conv2d.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/conv_parameter.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -109,7 +112,51 @@ int DeDepthwiseConv2D::GetActivationType() const { | |||||
| return this->primitive_->value_as_DeDepthwiseConv2D()->activationType(); | return this->primitive_->value_as_DeDepthwiseConv2D()->activationType(); | ||||
| } | } | ||||
| PrimitiveC *DeDepthwiseConv2DCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<DeDepthwiseConv2D>(primitive); | |||||
| } | |||||
| Registry DeDepthwiseConv2DRegistry(schema::PrimitiveType_DeDepthwiseConv2D, DeDepthwiseConv2DCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateDeconvDwParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| ConvParameter *conv_param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter))); | |||||
| if (conv_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ConvParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(conv_param, 0, sizeof(ConvParameter)); | |||||
| conv_param->op_parameter_.type_ = primitive->Type(); | |||||
| auto conv_primitive = | |||||
| reinterpret_cast<mindspore::lite::DeDepthwiseConv2D *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| conv_param->kernel_h_ = conv_primitive->GetKernelH(); | |||||
| conv_param->kernel_w_ = conv_primitive->GetKernelW(); | |||||
| conv_param->stride_h_ = conv_primitive->GetStrideH(); | |||||
| conv_param->stride_w_ = conv_primitive->GetStrideW(); | |||||
| auto deconvdw_lite_primitive = (mindspore::lite::DeDepthwiseConv2D *)primitive; | |||||
| conv_param->pad_u_ = deconvdw_lite_primitive->PadUp(); | |||||
| conv_param->pad_d_ = deconvdw_lite_primitive->PadDown(); | |||||
| conv_param->pad_l_ = deconvdw_lite_primitive->PadLeft(); | |||||
| conv_param->pad_r_ = deconvdw_lite_primitive->PadRight(); | |||||
| conv_param->dilation_h_ = conv_primitive->GetDilateH(); | |||||
| conv_param->dilation_w_ = conv_primitive->GetDilateW(); | |||||
| auto act_type = conv_primitive->GetActivationType(); | |||||
| switch (act_type) { | |||||
| case schema::ActivationType_RELU: | |||||
| conv_param->act_type_ = ActType_Relu; | |||||
| break; | |||||
| case schema::ActivationType_RELU6: | |||||
| conv_param->act_type_ = ActType_Relu6; | |||||
| break; | |||||
| default: | |||||
| conv_param->act_type_ = ActType_No; | |||||
| break; | |||||
| } | |||||
| return reinterpret_cast<OpParameter *>(conv_param); | |||||
| } | |||||
| Registry DeDepthwiseConv2DParameterRegistry(schema::PrimitiveType_DeDepthwiseConv2D, PopulateDeconvDwParameter); | |||||
| int DeDepthwiseConv2D::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) { | int DeDepthwiseConv2D::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) { | ||||
| if (inputs_.size() != kDoubleNum && inputs_.size() != kMultiNum) { | if (inputs_.size() != kDoubleNum && inputs_.size() != kMultiNum) { | ||||
| MS_LOG(ERROR) << "inputs number is invalid"; | MS_LOG(ERROR) << "inputs number is invalid"; | ||||
| @@ -17,6 +17,8 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -52,6 +54,9 @@ int Depend::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers: | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *DependCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Depend>(primitive); } | |||||
| Registry DependRegistry(schema::PrimitiveType_Depend, DependCreator); | |||||
| #endif | #endif | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,9 @@ | |||||
| #include "src/ops/depth_to_space.h" | #include "src/ops/depth_to_space.h" | ||||
| #include "src/common/common.h" | #include "src/common/common.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/depth_to_space_parameter.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -42,7 +45,29 @@ int DepthToSpace::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbu | |||||
| int DepthToSpace::GetBlockSize() const { return this->primitive_->value_as_DepthToSpace()->blockSize(); } | int DepthToSpace::GetBlockSize() const { return this->primitive_->value_as_DepthToSpace()->blockSize(); } | ||||
| int DepthToSpace::GetFormat() const { return this->primitive_->value_as_DepthToSpace()->format(); } | int DepthToSpace::GetFormat() const { return this->primitive_->value_as_DepthToSpace()->format(); } | ||||
| PrimitiveC *DepthToSpaceCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<DepthToSpace>(primitive); | |||||
| } | |||||
| Registry DepthToSpaceRegistry(schema::PrimitiveType_DepthToSpace, DepthToSpaceCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateDepthToSpaceParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| DepthToSpaceParameter *depth_space_param = | |||||
| reinterpret_cast<DepthToSpaceParameter *>(malloc(sizeof(DepthToSpaceParameter))); | |||||
| if (depth_space_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc DepthToSpaceParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(depth_space_param, 0, sizeof(DepthToSpaceParameter)); | |||||
| auto param = reinterpret_cast<mindspore::lite::DepthToSpace *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| depth_space_param->op_parameter_.type_ = primitive->Type(); | |||||
| depth_space_param->block_size_ = param->GetBlockSize(); | |||||
| return reinterpret_cast<OpParameter *>(depth_space_param); | |||||
| } | |||||
| Registry DepthToSpaceParameterRegistry(schema::PrimitiveType_DepthToSpace, PopulateDepthToSpaceParameter); | |||||
| namespace { | namespace { | ||||
| constexpr int kDepthToSpaceOutputNum = 1; | constexpr int kDepthToSpaceOutputNum = 1; | ||||
| constexpr int kDepthToSpaceInputNum = 1; | constexpr int kDepthToSpaceInputNum = 1; | ||||
| @@ -21,6 +21,9 @@ | |||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| #include "tools/converter/quantizer/quantize_util.h" | #include "tools/converter/quantizer/quantize_util.h" | ||||
| #endif | #endif | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/conv_parameter.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -191,7 +194,54 @@ int DepthwiseConv2D::GetActivationType() const { | |||||
| return this->primitive_->value_as_DepthwiseConv2D()->activationType(); | return this->primitive_->value_as_DepthwiseConv2D()->activationType(); | ||||
| } | } | ||||
| PrimitiveC *DepthWiseConv2DCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<DepthwiseConv2D>(primitive); | |||||
| } | |||||
| Registry DepthWiseConv2DRegistry(schema::PrimitiveType_DepthwiseConv2D, DepthWiseConv2DCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateConvDwParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| ConvParameter *conv_param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter))); | |||||
| if (conv_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ConvParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(conv_param, 0, sizeof(ConvParameter)); | |||||
| conv_param->op_parameter_.type_ = primitive->Type(); | |||||
| auto conv_primitive = | |||||
| reinterpret_cast<mindspore::lite::DepthwiseConv2D *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| conv_param->kernel_h_ = conv_primitive->GetKernelH(); | |||||
| conv_param->kernel_w_ = conv_primitive->GetKernelW(); | |||||
| conv_param->stride_h_ = conv_primitive->GetStrideH(); | |||||
| conv_param->stride_w_ = conv_primitive->GetStrideW(); | |||||
| auto convdw_lite_primitive = (lite::DepthwiseConv2D *)primitive; | |||||
| conv_param->pad_u_ = convdw_lite_primitive->PadUp(); | |||||
| conv_param->pad_d_ = convdw_lite_primitive->PadDown(); | |||||
| conv_param->pad_l_ = convdw_lite_primitive->PadLeft(); | |||||
| conv_param->pad_r_ = convdw_lite_primitive->PadRight(); | |||||
| conv_param->input_channel_ = convdw_lite_primitive->GetInputChannel(); | |||||
| conv_param->dilation_h_ = conv_primitive->GetDilateH(); | |||||
| conv_param->dilation_w_ = conv_primitive->GetDilateW(); | |||||
| auto act_type = conv_primitive->GetActivationType(); | |||||
| switch (act_type) { | |||||
| case schema::ActivationType_RELU: | |||||
| conv_param->act_type_ = ActType_Relu; | |||||
| break; | |||||
| case schema::ActivationType_RELU6: | |||||
| conv_param->act_type_ = ActType_Relu6; | |||||
| break; | |||||
| default: | |||||
| conv_param->act_type_ = ActType_No; | |||||
| break; | |||||
| } | |||||
| return reinterpret_cast<OpParameter *>(conv_param); | |||||
| } | |||||
| Registry DepthwiseConv2DParameterRegistry(schema::PrimitiveType_DepthwiseConv2D, PopulateConvDwParameter); | |||||
| int DepthwiseConv2D::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) { | int DepthwiseConv2D::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) { | ||||
| if (inputs_.size() != kDoubleNum && inputs_.size() != kMultiNum) { | if (inputs_.size() != kDoubleNum && inputs_.size() != kMultiNum) { | ||||
| MS_LOG(ERROR) << "inputs number is invalid"; | MS_LOG(ERROR) << "inputs number is invalid"; | ||||
| @@ -16,6 +16,9 @@ | |||||
| #include "src/ops/detection_post_process.h" | #include "src/ops/detection_post_process.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/detection_post_process_parameter.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -139,7 +142,38 @@ bool DetectionPostProcess::GetOutQuantized() const { | |||||
| return this->primitive_->value_as_DetectionPostProcess()->OutQuantized(); | return this->primitive_->value_as_DetectionPostProcess()->OutQuantized(); | ||||
| } | } | ||||
| PrimitiveC *DetectionPostProcessCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<DetectionPostProcess>(primitive); | |||||
| } | |||||
| Registry DetectionPostProcessRegistry(schema::PrimitiveType_DetectionPostProcess, DetectionPostProcessCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateDetectionPostProcessParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| DetectionPostProcessParameter *detection_post_process_parameter = | |||||
| reinterpret_cast<DetectionPostProcessParameter *>(malloc(sizeof(DetectionPostProcessParameter))); | |||||
| if (detection_post_process_parameter == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc EluParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(detection_post_process_parameter, 0, sizeof(DetectionPostProcessParameter)); | |||||
| detection_post_process_parameter->op_parameter_.type_ = primitive->Type(); | |||||
| auto param = | |||||
| reinterpret_cast<mindspore::lite::DetectionPostProcess *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| detection_post_process_parameter->h_scale_ = param->GetHScale(); | |||||
| detection_post_process_parameter->w_scale_ = param->GetWScale(); | |||||
| detection_post_process_parameter->x_scale_ = param->GetXScale(); | |||||
| detection_post_process_parameter->y_scale_ = param->GetYScale(); | |||||
| detection_post_process_parameter->nms_iou_threshold_ = param->GetNmsIouThreshold(); | |||||
| detection_post_process_parameter->nms_score_threshold_ = param->GetNmsScoreThreshold(); | |||||
| detection_post_process_parameter->max_detections_ = param->GetMaxDetections(); | |||||
| detection_post_process_parameter->detections_per_class_ = param->GetDetectionsPerClass(); | |||||
| detection_post_process_parameter->max_classes_per_detection_ = param->GetMaxClassesPerDetection(); | |||||
| detection_post_process_parameter->num_classes_ = param->GetNumClasses(); | |||||
| detection_post_process_parameter->use_regular_nms_ = param->GetUseRegularNms(); | |||||
| return reinterpret_cast<OpParameter *>(detection_post_process_parameter); | |||||
| } | |||||
| Registry DetectionPostProcessParameterRegistry(schema::PrimitiveType_DetectionPostProcess, | |||||
| PopulateDetectionPostProcessParameter); | |||||
| namespace { | namespace { | ||||
| constexpr int kDetectionPostProcessOutputNum = 4; | constexpr int kDetectionPostProcessOutputNum = 4; | ||||
| constexpr int kDetectionPostProcessInputNum = 3; | constexpr int kDetectionPostProcessInputNum = 3; | ||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/ops/div.h" | #include "src/ops/div.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -41,6 +43,30 @@ int Div::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl | |||||
| } | } | ||||
| int Div::GetActivationType() const { return this->primitive_->value_as_Div()->activationType(); } | int Div::GetActivationType() const { return this->primitive_->value_as_Div()->activationType(); } | ||||
| PrimitiveC *DivCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Div>(primitive); } | |||||
| Registry DivRegistry(schema::PrimitiveType_Div, DivCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateDivParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter))); | |||||
| if (arithmetic_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ArithmeticParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(arithmetic_param, 0, sizeof(ArithmeticParameter)); | |||||
| arithmetic_param->op_parameter_.type_ = primitive->Type(); | |||||
| arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting(); | |||||
| arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims(); | |||||
| arithmetic_param->activation_type_ = | |||||
| reinterpret_cast<mindspore::lite::Div *>(const_cast<mindspore::lite::PrimitiveC *>(primitive))->GetActivationType(); | |||||
| auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0(); | |||||
| memcpy(arithmetic_param->in_shape0_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); | |||||
| tmp_shape = ((lite::Arithmetic *)primitive)->InShape1(); | |||||
| memcpy(arithmetic_param->in_shape1_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); | |||||
| tmp_shape = ((lite::Arithmetic *)primitive)->OutputShape(); | |||||
| memcpy(arithmetic_param->out_shape_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); | |||||
| return reinterpret_cast<OpParameter *>(arithmetic_param); | |||||
| } | |||||
| Registry DivParameterRegistry(schema::PrimitiveType_Div, PopulateDivParameter); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/ops/dropout.h" | #include "src/ops/dropout.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -39,6 +41,8 @@ int Dropout::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers | |||||
| } | } | ||||
| float Dropout::GetRatio() const { return this->primitive_->value_as_Dropout()->ratio(); } | float Dropout::GetRatio() const { return this->primitive_->value_as_Dropout()->ratio(); } | ||||
| PrimitiveC *DropoutCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Dropout>(primitive); } | |||||
| Registry DropoutRegistry(schema::PrimitiveType_Dropout, DropoutCreator); | |||||
| #endif | #endif | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,9 @@ | |||||
| #include "src/ops/eltwise.h" | #include "src/ops/eltwise.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/arithmetic_common.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -39,6 +42,35 @@ int Eltwise::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers | |||||
| } | } | ||||
| int Eltwise::GetMode() const { return this->primitive_->value_as_Eltwise()->mode(); } | int Eltwise::GetMode() const { return this->primitive_->value_as_Eltwise()->mode(); } | ||||
| PrimitiveC *EltwiseCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Eltwise>(primitive); } | |||||
| Registry EltwiseRegistry(schema::PrimitiveType_Eltwise, EltwiseCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateEltwiseParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter))); | |||||
| if (arithmetic_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ArithmeticParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(arithmetic_param, 0, sizeof(ArithmeticParameter)); | |||||
| auto eltwise = reinterpret_cast<mindspore::lite::Eltwise *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| switch (eltwise->GetMode()) { | |||||
| case schema::EltwiseMode_PROD: | |||||
| arithmetic_param->op_parameter_.type_ = schema::PrimitiveType_Mul; | |||||
| break; | |||||
| case schema::EltwiseMode_SUM: | |||||
| arithmetic_param->op_parameter_.type_ = schema::PrimitiveType_Add; | |||||
| break; | |||||
| case schema::EltwiseMode_MAXIMUM: | |||||
| arithmetic_param->op_parameter_.type_ = schema::PrimitiveType_Maximum; | |||||
| break; | |||||
| default: | |||||
| free(arithmetic_param); | |||||
| return nullptr; | |||||
| } | |||||
| return reinterpret_cast<OpParameter *>(arithmetic_param); | |||||
| } | |||||
| Registry EltwiseParameterRegistry(schema::PrimitiveType_Eltwise, PopulateEltwiseParameter); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/ops/elu.h" | #include "src/ops/elu.h" | ||||
| #include <memory> | #include <memory> | ||||
| #include "nnacl/fp32/elu.h" | |||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -61,6 +63,22 @@ int Elu::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl | |||||
| } | } | ||||
| float Elu::GetAlpha() const { return this->primitive_->value_as_Elu()->alpha(); } | float Elu::GetAlpha() const { return this->primitive_->value_as_Elu()->alpha(); } | ||||
| PrimitiveC *EluCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Elu>(primitive); } | |||||
| Registry EluRegistry(schema::PrimitiveType_Elu, EluCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateEluParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| EluParameter *elu_parameter = reinterpret_cast<EluParameter *>(malloc(sizeof(EluParameter))); | |||||
| if (elu_parameter == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc EluParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(elu_parameter, 0, sizeof(EluParameter)); | |||||
| elu_parameter->op_parameter_.type_ = primitive->Type(); | |||||
| auto param = reinterpret_cast<mindspore::lite::Elu *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| elu_parameter->alpha_ = param->GetAlpha(); | |||||
| return reinterpret_cast<OpParameter *>(elu_parameter); | |||||
| } | |||||
| Registry EluParameterRegistry(schema::PrimitiveType_Elu, PopulateEluParameter); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,9 @@ | |||||
| #include "src/ops/embedding_lookup.h" | #include "src/ops/embedding_lookup.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/fp32/embedding_lookup.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -41,8 +44,35 @@ int EmbeddingLookup::UnPackToFlatBuilder(const schema::Primitive *primitive, fla | |||||
| } | } | ||||
| float EmbeddingLookup::GetMaxNorm() const { return this->primitive_->value_as_EmbeddingLookup()->maxNorm(); } | float EmbeddingLookup::GetMaxNorm() const { return this->primitive_->value_as_EmbeddingLookup()->maxNorm(); } | ||||
| PrimitiveC *EmbeddingLookupCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<EmbeddingLookup>(primitive); | |||||
| } | |||||
| Registry EmbeddingLookupRegistry(schema::PrimitiveType_EmbeddingLookup, EmbeddingLookupCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateEmbeddingLookupParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| EmbeddingLookupParameter *embedding_lookup_parameter = | |||||
| reinterpret_cast<EmbeddingLookupParameter *>(malloc(sizeof(EmbeddingLookupParameter))); | |||||
| if (embedding_lookup_parameter == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc EmbeddingLookupParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(embedding_lookup_parameter, 0, sizeof(EmbeddingLookupParameter)); | |||||
| embedding_lookup_parameter->op_parameter_.type_ = primitive->Type(); | |||||
| auto param = | |||||
| reinterpret_cast<mindspore::lite::EmbeddingLookup *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| embedding_lookup_parameter->max_norm_ = param->GetMaxNorm(); | |||||
| if (embedding_lookup_parameter->max_norm_ < 0) { | |||||
| MS_LOG(ERROR) << "Embedding lookup max norm should be positive number, got " | |||||
| << embedding_lookup_parameter->max_norm_; | |||||
| free(embedding_lookup_parameter); | |||||
| return nullptr; | |||||
| } | |||||
| return reinterpret_cast<OpParameter *>(embedding_lookup_parameter); | |||||
| } | |||||
| Registry EmbeddingLookupParameterRegistry(schema::PrimitiveType_EmbeddingLookup, PopulateEmbeddingLookupParameter); | |||||
| int EmbeddingLookup::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | int EmbeddingLookup::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | ||||
| MS_ASSERT(this->primitive_ != nullptr); | MS_ASSERT(this->primitive_ != nullptr); | ||||
| if (inputs_.size() < kDoubleNum) { | if (inputs_.size() < kDoubleNum) { | ||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/ops/embedding_lookup_sparse.h" | #include "src/ops/embedding_lookup_sparse.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -76,6 +78,10 @@ float EmbeddingLookupSparse::GetMaxNortm() const { | |||||
| return this->primitive_->value_as_EmbeddingLookupSparse()->maxNortm(); | return this->primitive_->value_as_EmbeddingLookupSparse()->maxNortm(); | ||||
| } | } | ||||
| PrimitiveC *EmbeddingLookupSparseCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<EmbeddingLookupSparse>(primitive); | |||||
| } | |||||
| Registry EmbeddingLookupSparseRegistry(schema::PrimitiveType_EmbeddingLookupSparse, EmbeddingLookupSparseCreator); | |||||
| #endif | #endif | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/ops/equal.h" | #include "src/ops/equal.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifndef PRIMITIVE_WRITEABLE | #ifndef PRIMITIVE_WRITEABLE | ||||
| @@ -28,6 +30,8 @@ int Equal::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:: | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *EqualCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Equal>(primitive); } | |||||
| Registry EqualRegistry(schema::PrimitiveType_Equal, EqualCreator); | |||||
| #endif | #endif | ||||
| int Equal::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | int Equal::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | ||||
| auto input = inputs_.front(); | auto input = inputs_.front(); | ||||
| @@ -39,5 +43,6 @@ int Equal::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outpu | |||||
| output->SetFormat(input->GetFormat()); | output->SetFormat(input->GetFormat()); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| Registry EqualParameterRegistry(schema::PrimitiveType_Equal, PopulateArithmetic); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,9 @@ | |||||
| #include "src/ops/exp.h" | #include "src/ops/exp.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "src/ops/arithmetic_self.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -71,6 +74,10 @@ float Exp::GetBase() const { return this->primitive_->value_as_Exp()->base(); } | |||||
| float Exp::GetScale() const { return this->primitive_->value_as_Exp()->scale(); } | float Exp::GetScale() const { return this->primitive_->value_as_Exp()->scale(); } | ||||
| float Exp::GetShift() const { return this->primitive_->value_as_Exp()->shift(); } | float Exp::GetShift() const { return this->primitive_->value_as_Exp()->shift(); } | ||||
| PrimitiveC *ExpCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Exp>(primitive); } | |||||
| Registry ExpRegistry(schema::PrimitiveType_Exp, ExpCreator); | |||||
| #endif | #endif | ||||
| Registry ExpParameterRegistry(schema::PrimitiveType_Exp, PopulateArithmeticSelf); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,9 @@ | |||||
| #include "src/ops/expand_dims.h" | #include "src/ops/expand_dims.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/fp32/expandDims.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -40,8 +43,27 @@ int ExpandDims::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuff | |||||
| } | } | ||||
| int ExpandDims::GetDim() const { return this->primitive_->value_as_ExpandDims()->dim(); } | int ExpandDims::GetDim() const { return this->primitive_->value_as_ExpandDims()->dim(); } | ||||
| PrimitiveC *ExpandDimsCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<ExpandDims>(primitive); | |||||
| } | |||||
| Registry ExpandDimsRegistry(schema::PrimitiveType_ExpandDims, ExpandDimsCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateExpandDimsParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| auto param = reinterpret_cast<mindspore::lite::ExpandDims *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| ExpandDimsParameter *expand_dims_param = reinterpret_cast<ExpandDimsParameter *>(malloc(sizeof(ExpandDimsParameter))); | |||||
| if (expand_dims_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ExpandDimsParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(expand_dims_param, 0, sizeof(ExpandDimsParameter)); | |||||
| expand_dims_param->op_parameter_.type_ = primitive->Type(); | |||||
| expand_dims_param->dim_ = param->GetDim(); | |||||
| return reinterpret_cast<OpParameter *>(expand_dims_param); | |||||
| } | |||||
| Registry ExpandDimsParameterRegistry(schema::PrimitiveType_ExpandDims, PopulateExpandDimsParameter); | |||||
| int ExpandDims::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | int ExpandDims::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | ||||
| MS_ASSERT(this->primitive_ != nullptr); | MS_ASSERT(this->primitive_ != nullptr); | ||||
| auto input = inputs_.front(); | auto input = inputs_.front(); | ||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/ops/fake_quant_with_min_max_vars.h" | #include "src/ops/fake_quant_with_min_max_vars.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -54,6 +56,10 @@ int FakeQuantWithMinMaxVars::GetNumBits() const { | |||||
| return this->primitive_->value_as_FakeQuantWithMinMaxVars()->numBits(); | return this->primitive_->value_as_FakeQuantWithMinMaxVars()->numBits(); | ||||
| } | } | ||||
| PrimitiveC *FakeQuantWithMinMaxVarsCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<FakeQuantWithMinMaxVars>(primitive); | |||||
| } | |||||
| Registry FakeQuantWithMinMaxVarsRegistry(schema::PrimitiveType_FakeQuantWithMinMaxVars, FakeQuantWithMinMaxVarsCreator); | |||||
| #endif | #endif | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,9 @@ | |||||
| #include "src/ops/fill.h" | #include "src/ops/fill.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/fp32/fill.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -48,8 +51,30 @@ std::vector<int> Fill::GetDims() const { | |||||
| return std::vector<int>(fb_vector->begin(), fb_vector->end()); | return std::vector<int>(fb_vector->begin(), fb_vector->end()); | ||||
| } | } | ||||
| PrimitiveC *FillCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Fill>(primitive); } | |||||
| Registry FillRegistry(schema::PrimitiveType_Fill, FillCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateFillParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| const auto param = reinterpret_cast<mindspore::lite::Fill *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| FillParameter *fill_param = reinterpret_cast<FillParameter *>(malloc(sizeof(FillParameter))); | |||||
| if (fill_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc FillParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(fill_param, 0, sizeof(FillParameter)); | |||||
| fill_param->op_parameter_.type_ = primitive->Type(); | |||||
| auto flatDims = param->GetDims(); | |||||
| fill_param->num_dims_ = flatDims.size(); | |||||
| int i = 0; | |||||
| for (auto iter = flatDims.begin(); iter != flatDims.end(); iter++) { | |||||
| fill_param->dims_[i++] = *iter; | |||||
| } | |||||
| return reinterpret_cast<OpParameter *>(fill_param); | |||||
| } | |||||
| Registry FillParameterRegistry(schema::PrimitiveType_Fill, PopulateFillParameter); | |||||
| int Fill::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | int Fill::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | ||||
| MS_ASSERT(this->primitive_ != nullptr); | MS_ASSERT(this->primitive_ != nullptr); | ||||
| auto input = inputs_.front(); | auto input = inputs_.front(); | ||||
| @@ -17,6 +17,9 @@ | |||||
| #include "src/ops/flatten.h" | #include "src/ops/flatten.h" | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/flatten.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -86,6 +89,22 @@ int Flatten::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *FlattenCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Flatten>(primitive); } | |||||
| Registry FlattenRegistry(schema::PrimitiveType_Flatten, FlattenCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateFlattenParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| FlattenParameter *flatten_param = reinterpret_cast<FlattenParameter *>(malloc(sizeof(FlattenParameter))); | |||||
| if (flatten_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc FlattenParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(flatten_param, 0, sizeof(FlattenParameter)); | |||||
| flatten_param->op_parameter_.type_ = primitive->Type(); | |||||
| return reinterpret_cast<OpParameter *>(flatten_param); | |||||
| } | |||||
| Registry FlattenParameterRegistry(schema::PrimitiveType_Flatten, PopulateFlattenParameter); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -17,6 +17,8 @@ | |||||
| #include "src/ops/flatten_grad.h" | #include "src/ops/flatten_grad.h" | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| int FlattenGrad::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | int FlattenGrad::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | ||||
| @@ -85,6 +87,10 @@ int FlattenGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuf | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *FlattenGradCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<FlattenGrad>(primitive); | |||||
| } | |||||
| Registry FlattenGradRegistry(schema::PrimitiveType_FlattenGrad, FlattenGradCreator); | |||||
| #endif | #endif | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/ops/floor.h" | #include "src/ops/floor.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifndef PRIMITIVE_WRITEABLE | #ifndef PRIMITIVE_WRITEABLE | ||||
| @@ -28,7 +30,10 @@ int Floor::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:: | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *FloorCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Floor>(primitive); } | |||||
| Registry FloorRegistry(schema::PrimitiveType_Floor, FloorCreator); | |||||
| #endif | #endif | ||||
| Registry FloorParameterRegistry(schema::PrimitiveType_Floor, PopulateArithmeticSelf); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/ops/floor_div.h" | #include "src/ops/floor_div.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifndef PRIMITIVE_WRITEABLE | #ifndef PRIMITIVE_WRITEABLE | ||||
| @@ -29,6 +31,11 @@ int FloorDiv::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffer | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *FloorDivCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<FloorDiv>(primitive); | |||||
| } | |||||
| Registry FloorDivRegistry(schema::PrimitiveType_FloorDiv, FloorDivCreator); | |||||
| #endif | #endif | ||||
| Registry FloorDivParameterRegistry(schema::PrimitiveType_FloorDiv, PopulateArithmetic); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/ops/floor_mod.h" | #include "src/ops/floor_mod.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifndef PRIMITIVE_WRITEABLE | #ifndef PRIMITIVE_WRITEABLE | ||||
| @@ -28,7 +30,11 @@ int FloorMod::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffer | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *FloorModCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<FloorMod>(primitive); | |||||
| } | |||||
| Registry FloorModRegistry(schema::PrimitiveType_FloorMod, FloorModCreator); | |||||
| #endif | #endif | ||||
| Registry FloorModParameterRegistry(schema::PrimitiveType_FloorMod, PopulateArithmetic); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,9 @@ | |||||
| #include "src/ops/full_connection.h" | #include "src/ops/full_connection.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/matmul_parameter.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -51,7 +54,38 @@ int FullConnection::GetAxis() const { return this->primitive_->value_as_FullConn | |||||
| bool FullConnection::GetUseAxis() const { return this->primitive_->value_as_FullConnection()->useAxis(); } | bool FullConnection::GetUseAxis() const { return this->primitive_->value_as_FullConnection()->useAxis(); } | ||||
| int FullConnection::GetActivationType() const { return this->primitive_->value_as_FullConnection()->activationType(); } | int FullConnection::GetActivationType() const { return this->primitive_->value_as_FullConnection()->activationType(); } | ||||
| PrimitiveC *FullConnectionCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<FullConnection>(primitive); | |||||
| } | |||||
| Registry FullConnectionRegistry(schema::PrimitiveType_FullConnection, FullConnectionCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateFullconnectionParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| auto param = | |||||
| reinterpret_cast<mindspore::lite::FullConnection *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| MatMulParameter *matmul_param = reinterpret_cast<MatMulParameter *>(malloc(sizeof(MatMulParameter))); | |||||
| if (matmul_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc MatMulParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(matmul_param, 0, sizeof(MatMulParameter)); | |||||
| matmul_param->op_parameter_.type_ = primitive->Type(); | |||||
| matmul_param->b_transpose_ = true; | |||||
| matmul_param->a_transpose_ = false; | |||||
| matmul_param->has_bias_ = param->GetHasBias(); | |||||
| if (param->GetActivationType() == schema::ActivationType_RELU) { | |||||
| matmul_param->act_type_ = ActType_Relu; | |||||
| } else if (param->GetActivationType() == schema::ActivationType_RELU6) { | |||||
| matmul_param->act_type_ = ActType_Relu6; | |||||
| } else { | |||||
| matmul_param->act_type_ = ActType_No; | |||||
| } | |||||
| return reinterpret_cast<OpParameter *>(matmul_param); | |||||
| } | |||||
| Registry FullConnectionParameterRegistry(schema::PrimitiveType_FullConnection, PopulateFullconnectionParameter); | |||||
| int FullConnection::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) { | int FullConnection::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) { | ||||
| MS_ASSERT(this->primitive_ != nullptr); | MS_ASSERT(this->primitive_ != nullptr); | ||||
| auto input0 = inputs_.front(); | auto input0 = inputs_.front(); | ||||
| @@ -16,6 +16,9 @@ | |||||
| #include "src/ops/fused_batchnorm.h" | #include "src/ops/fused_batchnorm.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/batchnorm_parameter.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -72,7 +75,29 @@ float FusedBatchNorm::GetEpsilon() const { return this->primitive_->value_as_Fus | |||||
| float FusedBatchNorm::GetMomentum() const { return this->primitive_->value_as_FusedBatchNorm()->momentum(); } | float FusedBatchNorm::GetMomentum() const { return this->primitive_->value_as_FusedBatchNorm()->momentum(); } | ||||
| int FusedBatchNorm::GetSpatial() const { return this->primitive_->value_as_FusedBatchNorm()->spatial(); } | int FusedBatchNorm::GetSpatial() const { return this->primitive_->value_as_FusedBatchNorm()->spatial(); } | ||||
| PrimitiveC *FusedBatchNormCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<FusedBatchNorm>(primitive); | |||||
| } | |||||
| Registry FusedBatchNormRegistry(schema::PrimitiveType_FusedBatchNorm, FusedBatchNormCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateFusedBatchNorm(const mindspore::lite::PrimitiveC *primitive) { | |||||
| BatchNormParameter *batch_norm_param = reinterpret_cast<BatchNormParameter *>(malloc(sizeof(BatchNormParameter))); | |||||
| if (batch_norm_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc BatchNormParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(batch_norm_param, 0, sizeof(BatchNormParameter)); | |||||
| batch_norm_param->op_parameter_.type_ = primitive->Type(); | |||||
| auto param = | |||||
| reinterpret_cast<mindspore::lite::FusedBatchNorm *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| batch_norm_param->epsilon_ = param->GetEpsilon(); | |||||
| batch_norm_param->momentum_ = param->GetMomentum(); | |||||
| batch_norm_param->fused_ = true; | |||||
| return reinterpret_cast<OpParameter *>(batch_norm_param); | |||||
| } | |||||
| Registry FusedBatchNormParameterRegistry(schema::PrimitiveType_FusedBatchNorm, PopulateFusedBatchNorm); | |||||
| int FusedBatchNorm::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) { | int FusedBatchNorm::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) { | ||||
| for (size_t i = 0; i < inputs_.size(); i++) { | for (size_t i = 0; i < inputs_.size(); i++) { | ||||
| if (outputs_.size() <= i) break; | if (outputs_.size() <= i) break; | ||||
| @@ -19,6 +19,9 @@ | |||||
| #include "src/common/log_adapter.h" | #include "src/common/log_adapter.h" | ||||
| #include "src/tensor.h" | #include "src/tensor.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/gather_parameter.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -46,8 +49,25 @@ int Gather::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers: | |||||
| int Gather::GetAxis() const { return this->primitive_->value_as_Gather()->axis(); } | int Gather::GetAxis() const { return this->primitive_->value_as_Gather()->axis(); } | ||||
| int Gather::GetBatchDims() const { return this->primitive_->value_as_Gather()->batchDims(); } | int Gather::GetBatchDims() const { return this->primitive_->value_as_Gather()->batchDims(); } | ||||
| PrimitiveC *GatherCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Gather>(primitive); } | |||||
| Registry GatherRegistry(schema::PrimitiveType_Gather, GatherCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateGatherParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| auto gather_attr = reinterpret_cast<mindspore::lite::Gather *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| GatherParameter *gather_param = reinterpret_cast<GatherParameter *>(malloc(sizeof(GatherParameter))); | |||||
| if (gather_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc GatherParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(gather_param, 0, sizeof(GatherParameter)); | |||||
| gather_param->op_parameter_.type_ = primitive->Type(); | |||||
| gather_param->axis_ = gather_attr->GetAxis(); | |||||
| gather_param->batchDims_ = gather_attr->GetBatchDims(); | |||||
| return reinterpret_cast<OpParameter *>(gather_param); | |||||
| } | |||||
| Registry GatherParameterRegistry(schema::PrimitiveType_Gather, PopulateGatherParameter); | |||||
| int Gather::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | int Gather::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | ||||
| MS_ASSERT(this->primitive_ != nullptr); | MS_ASSERT(this->primitive_ != nullptr); | ||||
| if (inputs_.size() != kDoubleNum) { | if (inputs_.size() != kDoubleNum) { | ||||
| @@ -16,6 +16,9 @@ | |||||
| #include "src/ops/gather_nd.h" | #include "src/ops/gather_nd.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/fp32/gatherNd.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -40,8 +43,28 @@ int GatherNd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffer | |||||
| } | } | ||||
| int GatherNd::GetBatchDims() const { return this->primitive_->value_as_GatherNd()->batchDims(); } | int GatherNd::GetBatchDims() const { return this->primitive_->value_as_GatherNd()->batchDims(); } | ||||
| PrimitiveC *GatherNdCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<GatherNd>(primitive); | |||||
| } | |||||
| Registry GatherNdRegistry(schema::PrimitiveType_GatherNd, GatherNdCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateGatherNdParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| GatherNdParameter *gather_nd_param = reinterpret_cast<GatherNdParameter *>(malloc(sizeof(GatherNdParameter))); | |||||
| if (gather_nd_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc GatherNdParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(gather_nd_param, 0, sizeof(GatherNdParameter)); | |||||
| gather_nd_param->op_parameter_.type_ = primitive->Type(); | |||||
| auto gatherNd_attr = | |||||
| reinterpret_cast<mindspore::lite::GatherNd *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| gather_nd_param->batchDims_ = gatherNd_attr->GetBatchDims(); | |||||
| return reinterpret_cast<OpParameter *>(gather_nd_param); | |||||
| } | |||||
| Registry GatherNdParameterRegistry(schema::PrimitiveType_GatherNd, PopulateGatherNdParameter); | |||||
| int GatherNd::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | int GatherNd::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | ||||
| MS_ASSERT(this->primitive_ != nullptr); | MS_ASSERT(this->primitive_ != nullptr); | ||||
| if (inputs_.size() != kDoubleNum) { | if (inputs_.size() != kDoubleNum) { | ||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/ops/greater.h" | #include "src/ops/greater.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifndef PRIMITIVE_WRITEABLE | #ifndef PRIMITIVE_WRITEABLE | ||||
| @@ -28,6 +30,9 @@ int Greater::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *GreaterCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Greater>(primitive); } | |||||
| Registry GreaterRegistry(schema::PrimitiveType_Greater, GreaterCreator); | |||||
| #endif | #endif | ||||
| int Greater::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | int Greater::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | ||||
| auto input = inputs_.front(); | auto input = inputs_.front(); | ||||
| @@ -39,5 +44,6 @@ int Greater::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> out | |||||
| output->SetFormat(input->GetFormat()); | output->SetFormat(input->GetFormat()); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| Registry GreaterParameterRegistry(schema::PrimitiveType_Greater, PopulateArithmetic); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/ops/greater_equal.h" | #include "src/ops/greater_equal.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifndef PRIMITIVE_WRITEABLE | #ifndef PRIMITIVE_WRITEABLE | ||||
| @@ -27,6 +29,12 @@ int GreaterEqual::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbu | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *GreaterEqualCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<GreaterEqual>(primitive); | |||||
| } | |||||
| Registry GreaterEqualRegistry(schema::PrimitiveType_GreaterEqual, GreaterEqualCreator); | |||||
| #endif | #endif | ||||
| int GreaterEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | int GreaterEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | ||||
| auto input = inputs_.front(); | auto input = inputs_.front(); | ||||
| @@ -38,5 +46,6 @@ int GreaterEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor * | |||||
| output->SetFormat(input->GetFormat()); | output->SetFormat(input->GetFormat()); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| Registry GreaterEqualParameterRegistry(schema::PrimitiveType_GreaterEqual, PopulateArithmetic); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/ops/group_conv2d_grad_input.h" | #include "src/ops/group_conv2d_grad_input.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -127,6 +129,10 @@ int GroupConv2DGradInput::GetActivationType() const { | |||||
| return this->primitive_->value_as_GroupConv2DGradInput()->activationType(); | return this->primitive_->value_as_GroupConv2DGradInput()->activationType(); | ||||
| } | } | ||||
| PrimitiveC *GroupConv2DGradInputCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<GroupConv2DGradInput>(primitive); | |||||
| } | |||||
| Registry GroupConv2DGradInputRegistry(schema::PrimitiveType_GroupConv2DGradInput, GroupConv2DGradInputCreator); | |||||
| #endif | #endif | ||||
| int GroupConv2DGradInput::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) { | int GroupConv2DGradInput::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) { | ||||
| @@ -17,6 +17,8 @@ | |||||
| #include "src/common/string_util.h" | #include "src/common/string_util.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -30,7 +32,24 @@ int HashtableLookup::UnPackToFlatBuilder(const schema::Primitive *primitive, fla | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *HashtableLookupCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<HashtableLookup>(primitive); | |||||
| } | |||||
| Registry HashtableLookupRegistry(schema::PrimitiveType_HashtableLookup, HashtableLookupCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateHashtableLookupParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| OpParameter *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "new OpParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(OpParameter)); | |||||
| param->type_ = primitive->Type(); | |||||
| return param; | |||||
| } | |||||
| Registry HashtableLookupParameterRegistry(schema::PrimitiveType_HashtableLookup, PopulateHashtableLookupParameter); | |||||
| int HashtableLookup::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | int HashtableLookup::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | ||||
| auto input = inputs_.at(0); | auto input = inputs_.at(0); | ||||
| auto values = inputs_.at(2); | auto values = inputs_.at(2); | ||||
| @@ -16,6 +16,9 @@ | |||||
| #include "src/ops/l2_norm.h" | #include "src/ops/l2_norm.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/l2_norm_parameter.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -57,6 +60,44 @@ std::vector<int> L2Norm::GetAxis() const { | |||||
| float L2Norm::GetEpsilon() const { return this->primitive_->value_as_L2Norm()->epsilon(); } | float L2Norm::GetEpsilon() const { return this->primitive_->value_as_L2Norm()->epsilon(); } | ||||
| int L2Norm::GetActivationType() const { return this->primitive_->value_as_L2Norm()->activationType(); } | int L2Norm::GetActivationType() const { return this->primitive_->value_as_L2Norm()->activationType(); } | ||||
| PrimitiveC *L2NormCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<L2Norm>(primitive); } | |||||
| Registry L2NormRegistry(schema::PrimitiveType_L2Norm, L2NormCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateL2NormParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| L2NormParameter *l2_norm_parameter = reinterpret_cast<L2NormParameter *>(malloc(sizeof(L2NormParameter))); | |||||
| if (l2_norm_parameter == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc L2NormParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(l2_norm_parameter, 0, sizeof(L2NormParameter)); | |||||
| l2_norm_parameter->op_parameter_.type_ = primitive->Type(); | |||||
| auto param = reinterpret_cast<mindspore::lite::L2Norm *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| auto axis_vec = param->GetAxis(); | |||||
| l2_norm_parameter->axis_num_ = axis_vec.size(); | |||||
| l2_norm_parameter->axis_ = reinterpret_cast<int *>(malloc(axis_vec.size() * sizeof(int))); | |||||
| if (l2_norm_parameter->axis_ == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc axis_ data failed"; | |||||
| free(l2_norm_parameter); | |||||
| return nullptr; | |||||
| } | |||||
| for (size_t i = 0; i < axis_vec.size(); i++) { | |||||
| l2_norm_parameter->axis_[i] = axis_vec[i]; | |||||
| } | |||||
| if (param->GetEpsilon() < 1e-6) { | |||||
| l2_norm_parameter->epsilon_ = 1e-6; | |||||
| } else { | |||||
| l2_norm_parameter->epsilon_ = param->GetEpsilon(); | |||||
| } | |||||
| if (param->GetActivationType() == static_cast<int>(schema::ActivationType_RELU)) { | |||||
| l2_norm_parameter->act_type_ = ActType_Relu; | |||||
| } else if (param->GetActivationType() == static_cast<int>(schema::ActivationType_RELU6)) { | |||||
| l2_norm_parameter->act_type_ = ActType_Relu6; | |||||
| } else { | |||||
| l2_norm_parameter->act_type_ = ActType_No; | |||||
| } | |||||
| return reinterpret_cast<OpParameter *>(l2_norm_parameter); | |||||
| } | |||||
| Registry L2NormParameterRegistry(schema::PrimitiveType_L2Norm, PopulateL2NormParameter); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/ops/leaky_relu.h" | #include "src/ops/leaky_relu.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -42,6 +44,10 @@ int LeakyReLU::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffe | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *LeakyReLUCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<LeakyReLU>(primitive); | |||||
| } | |||||
| Registry LeakyReLURegistry(schema::PrimitiveType_LeakyReLU, LeakyReLUCreator); | |||||
| #endif | #endif | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/ops/less.h" | #include "src/ops/less.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -29,6 +31,10 @@ int Less::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::F | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *LessCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Less>(primitive); } | |||||
| Registry LessRegistry(schema::PrimitiveType_Less, LessCreator); | |||||
| #endif | #endif | ||||
| int Less::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | int Less::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | ||||
| auto input = inputs_.front(); | auto input = inputs_.front(); | ||||
| @@ -40,5 +46,6 @@ int Less::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output | |||||
| output->SetFormat(input->GetFormat()); | output->SetFormat(input->GetFormat()); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| Registry LessParameterRegistry(schema::PrimitiveType_Less, PopulateArithmetic); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/ops/less_equal.h" | #include "src/ops/less_equal.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -28,6 +30,10 @@ int LessEqual::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffe | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *LessEqualCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<LessEqual>(primitive); | |||||
| } | |||||
| Registry LessEqualRegistry(schema::PrimitiveType_LessEqual, LessEqualCreator); | |||||
| #endif | #endif | ||||
| int LessEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | int LessEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | ||||
| auto input = inputs_.front(); | auto input = inputs_.front(); | ||||
| @@ -39,5 +45,6 @@ int LessEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> o | |||||
| output->SetFormat(input->GetFormat()); | output->SetFormat(input->GetFormat()); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| Registry LessEqualParameterRegistry(schema::PrimitiveType_LessEqual, PopulateArithmetic); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,9 @@ | |||||
| #include "src/ops/local_response_normalization.h" | #include "src/ops/local_response_normalization.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/fp32/local_response_norm.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -76,6 +79,34 @@ int LocalResponseNormalization::UnPackToFlatBuilder(const schema::Primitive *pri | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *LocalResponseNormalizationCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<LocalResponseNormalization>(primitive); | |||||
| } | |||||
| Registry LocalResponseNormalizationRegistry(schema::PrimitiveType_LocalResponseNormalization, | |||||
| LocalResponseNormalizationCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateLocalResponseNormParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| auto local_response_norm_attr = reinterpret_cast<mindspore::lite::LocalResponseNormalization *>( | |||||
| const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| LocalResponseNormParameter *lrn_param = | |||||
| reinterpret_cast<LocalResponseNormParameter *>(malloc(sizeof(LocalResponseNormParameter))); | |||||
| if (lrn_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc LocalResponseNormParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(lrn_param, 0, sizeof(LocalResponseNormParameter)); | |||||
| lrn_param->op_parameter_.type_ = primitive->Type(); | |||||
| lrn_param->depth_radius_ = local_response_norm_attr->GetDepthRadius(); | |||||
| lrn_param->bias_ = local_response_norm_attr->GetBias(); | |||||
| lrn_param->alpha_ = local_response_norm_attr->GetAlpha(); | |||||
| lrn_param->beta_ = local_response_norm_attr->GetBeta(); | |||||
| return reinterpret_cast<OpParameter *>(lrn_param); | |||||
| } | |||||
| Registry LocalResponseNormalizationParameterRegistry(schema::PrimitiveType_LocalResponseNormalization, | |||||
| PopulateLocalResponseNormParameter); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -17,6 +17,8 @@ | |||||
| #include "src/ops/log.h" | #include "src/ops/log.h" | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -50,6 +52,11 @@ int Log::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *LogCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Log>(primitive); } | |||||
| Registry LogRegistry(schema::PrimitiveType_Log, LogCreator); | |||||
| #endif | #endif | ||||
| Registry LogParameterRegistry(schema::PrimitiveType_Log, PopulateArithmeticSelf); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,9 @@ | |||||
| #include "src/ops/log_grad.h" | #include "src/ops/log_grad.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "src/ops/arithmetic_self.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifndef PRIMITIVE_WRITEABLE | #ifndef PRIMITIVE_WRITEABLE | ||||
| @@ -32,6 +35,11 @@ int LogGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *LogGradCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<LogGrad>(primitive); } | |||||
| Registry LogGradRegistry(schema::PrimitiveType_LogGrad, LogGradCreator); | |||||
| #endif | #endif | ||||
| Registry LogGradParameterRegistry(schema::PrimitiveType_LogGrad, PopulateArithmeticSelf); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/ops/logical_and.h" | #include "src/ops/logical_and.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -28,6 +30,13 @@ int LogicalAnd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuff | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *LogicalAndCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<LogicalAnd>(primitive); | |||||
| } | |||||
| Registry LogicalAndRegistry(schema::PrimitiveType_LogicalAnd, LogicalAndCreator); | |||||
| #endif | #endif | ||||
| Registry LogicalAndParameterRegistry(schema::PrimitiveType_LogicalAnd, PopulateArithmetic); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/ops/logical_not.h" | #include "src/ops/logical_not.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -28,6 +30,12 @@ int LogicalNot::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuff | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *LogicalNotCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<LogicalNot>(primitive); | |||||
| } | |||||
| Registry LogicalNotRegistry(schema::PrimitiveType_LogicalNot, LogicalNotCreator); | |||||
| #endif | #endif | ||||
| Registry LogicalNotParameterRegistry(schema::PrimitiveType_LogicalNot, PopulateArithmeticSelf); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/ops/logical_or.h" | #include "src/ops/logical_or.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -28,6 +30,12 @@ int LogicalOr::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffe | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *LogicalOrCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<LogicalOr>(primitive); | |||||
| } | |||||
| Registry LogicalOrRegistry(schema::PrimitiveType_LogicalOr, LogicalOrCreator); | |||||
| #endif | #endif | ||||
| Registry LogicalOrParameterRegistry(schema::PrimitiveType_LogicalOr, PopulateArithmetic); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/ops/lrn.h" | #include "src/ops/lrn.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -49,6 +51,9 @@ int Lrn::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *LrnCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Lrn>(primitive); } | |||||
| Registry LrnRegistry(schema::PrimitiveType_Lrn, LrnCreator); | |||||
| #endif | #endif | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -17,6 +17,8 @@ | |||||
| #include "nnacl/lsh_projection_parameter.h" | #include "nnacl/lsh_projection_parameter.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -38,7 +40,29 @@ int LshProjection::UnPackToFlatBuilder(const schema::Primitive *primitive, flatb | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *LshProjectionCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<LshProjection>(primitive); | |||||
| } | |||||
| Registry LshProjectionRegistry(schema::PrimitiveType_LshProjection, LshProjectionCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateLshProjectionParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| LshProjectionParameter *lsh_project_param = | |||||
| reinterpret_cast<LshProjectionParameter *>(malloc(sizeof(LshProjectionParameter))); | |||||
| if (lsh_project_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc LshProjectionParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(lsh_project_param, 0, sizeof(LshProjectionParameter)); | |||||
| lsh_project_param->op_parameter_.type_ = primitive->Type(); | |||||
| auto param = reinterpret_cast<mindspore::lite::LshProjection *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| lsh_project_param->lsh_type_ = param->GetLshType(); | |||||
| return reinterpret_cast<OpParameter *>(lsh_project_param); | |||||
| } | |||||
| Registry LshProjectionParameterRegistry(schema::PrimitiveType_LshProjection, PopulateLshProjectionParameter); | |||||
| int LshProjection::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | int LshProjection::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | ||||
| if (inputs_.size() != kDoubleNum && inputs_.size() != kMultiNum) { | if (inputs_.size() != kDoubleNum && inputs_.size() != kMultiNum) { | ||||
| MS_LOG(ERROR) << "inputs to LshProjection operator should be 2 or 3, but " << inputs_.size() << " is given."; | MS_LOG(ERROR) << "inputs to LshProjection operator should be 2 or 3, but " << inputs_.size() << " is given."; | ||||
| @@ -16,6 +16,9 @@ | |||||
| #include "src/ops/lstm.h" | #include "src/ops/lstm.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/fp32/lstm.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -39,8 +42,31 @@ int Lstm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::F | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *LstmCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Lstm>(primitive); } | |||||
| Registry LstmRegistry(schema::PrimitiveType_Lstm, LstmCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateLstmParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| LstmParameter *lstm_param = reinterpret_cast<LstmParameter *>(malloc(sizeof(LstmParameter))); | |||||
| if (lstm_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc LstmParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(lstm_param, 0, sizeof(LstmParameter)); | |||||
| lstm_param->op_parameter_.type_ = primitive->Type(); | |||||
| auto param = reinterpret_cast<mindspore::lite::Lstm *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| if (param == nullptr) { | |||||
| free(lstm_param); | |||||
| MS_LOG(ERROR) << "get Lstm param nullptr."; | |||||
| return nullptr; | |||||
| } | |||||
| lstm_param->bidirectional_ = param->GetBidirection(); | |||||
| return reinterpret_cast<OpParameter *>(lstm_param); | |||||
| } | |||||
| Registry LstmParameterRegistry(schema::PrimitiveType_Lstm, PopulateLstmParameter); | |||||
| const int kLstmInputNum = 6; | const int kLstmInputNum = 6; | ||||
| const int kLstmOutputNum = 3; | const int kLstmOutputNum = 3; | ||||
| int Lstm::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | int Lstm::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | ||||
| @@ -18,6 +18,8 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -57,6 +59,11 @@ int MakeTuple::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffe | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *MakeTupleCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<MakeTuple>(primitive); | |||||
| } | |||||
| Registry MakeTupleRegistry(schema::PrimitiveType_MakeTuple, MakeTupleCreator); | |||||
| #endif | #endif | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -21,6 +21,9 @@ | |||||
| #include "tools/converter/quantizer/quantize_util.h" | #include "tools/converter/quantizer/quantize_util.h" | ||||
| #endif | #endif | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/matmul_parameter.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -86,8 +89,27 @@ int MatMul::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers: | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *MatMulCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<MatMul>(primitive); } | |||||
| Registry MatMulRegistry(schema::PrimitiveType_MatMul, MatMulCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateMatMulParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| auto param = reinterpret_cast<mindspore::lite::MatMul *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| MatMulParameter *matmul_param = reinterpret_cast<MatMulParameter *>(malloc(sizeof(MatMulParameter))); | |||||
| if (matmul_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc MatMulParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(matmul_param, 0, sizeof(MatMulParameter)); | |||||
| matmul_param->op_parameter_.type_ = primitive->Type(); | |||||
| matmul_param->b_transpose_ = param->GetTransposeB(); | |||||
| matmul_param->a_transpose_ = param->GetTransposeA(); | |||||
| matmul_param->has_bias_ = false; | |||||
| matmul_param->act_type_ = ActType_No; | |||||
| return reinterpret_cast<OpParameter *>(matmul_param); | |||||
| } | |||||
| Registry MatMulParameterRegistry(schema::PrimitiveType_MatMul, PopulateMatMulParameter); | |||||
| int MatMul::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | int MatMul::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | ||||
| MS_ASSERT(this->primitive_ != nullptr); | MS_ASSERT(this->primitive_ != nullptr); | ||||
| auto input0 = inputs_.front(); | auto input0 = inputs_.front(); | ||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/ops/matrix_diag.h" | #include "src/ops/matrix_diag.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -52,6 +54,11 @@ int MatrixDiag::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuff | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *MatrixDiagCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<MatrixDiag>(primitive); | |||||
| } | |||||
| Registry MatrixDiagRegistry(schema::PrimitiveType_MatrixDiag, MatrixDiagCreator); | |||||
| #endif | #endif | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -23,6 +23,8 @@ | |||||
| #include "tools/converter/quantizer/quantize_util.h" | #include "tools/converter/quantizer/quantize_util.h" | ||||
| #endif | #endif | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -62,6 +64,10 @@ int Maximum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *MaximumCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Maximum>(primitive); } | |||||
| Registry MaximumRegistry(schema::PrimitiveType_Maximum, MaximumCreator); | |||||
| #endif | #endif | ||||
| Registry MaximumParameterRegistry(schema::PrimitiveType_Maximum, PopulateArithmetic); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,9 @@ | |||||
| #include "src/ops/mean.h" | #include "src/ops/mean.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/reduce_parameter.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -53,8 +56,36 @@ int Mean::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::F | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *MeanCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Mean>(primitive); } | |||||
| Registry MeanRegistry(schema::PrimitiveType_Mean, MeanCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateMeanParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| ReduceParameter *mean_param = reinterpret_cast<ReduceParameter *>(malloc(sizeof(ReduceParameter))); | |||||
| if (mean_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ReduceParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(mean_param, 0, sizeof(ReduceParameter)); | |||||
| mean_param->op_parameter_.type_ = primitive->Type(); | |||||
| auto mean = reinterpret_cast<mindspore::lite::Mean *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| mean_param->keep_dims_ = mean->GetKeepDims(); | |||||
| auto axisVector = mean->GetAxis(); | |||||
| if (axisVector.size() > REDUCE_MAX_AXES_NUM) { | |||||
| MS_LOG(ERROR) << "Reduce axes size " << axisVector.size() << " exceed limit " << REDUCE_MAX_AXES_NUM; | |||||
| free(mean_param); | |||||
| return nullptr; | |||||
| } | |||||
| mean_param->num_axes_ = static_cast<int>(axisVector.size()); | |||||
| int i = 0; | |||||
| for (auto iter = axisVector.begin(); iter != axisVector.end(); iter++) { | |||||
| mean_param->axes_[i++] = *iter; | |||||
| } | |||||
| mean_param->mode_ = static_cast<int>(schema::ReduceMode_ReduceMean); | |||||
| return reinterpret_cast<OpParameter *>(mean_param); | |||||
| } | |||||
| Registry MeanParameterRegistry(schema::PrimitiveType_Mean, PopulateMeanParameter); | |||||
| namespace { | namespace { | ||||
| constexpr size_t kInputSize = 1; | constexpr size_t kInputSize = 1; | ||||
| constexpr size_t kOutputSize = 1; | constexpr size_t kOutputSize = 1; | ||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/ops/minimum.h" | #include "src/ops/minimum.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -28,6 +30,10 @@ int Minimum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *MinimumCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Minimum>(primitive); } | |||||
| Registry MinimumRegistry(schema::PrimitiveType_Minimum, MinimumCreator); | |||||
| #endif | #endif | ||||
| Registry MinimumParameterRegistry(schema::PrimitiveType_Minimum, PopulateArithmetic); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/ops/mul.h" | #include "src/ops/mul.h" | ||||
| #include <memory> | #include <memory> | ||||
| #include "nnacl/arithmetic_common.h" | |||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -72,6 +74,30 @@ int Mul::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *MulCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Mul>(primitive); } | |||||
| Registry MulRegistry(schema::PrimitiveType_Mul, MulCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateMulParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter))); | |||||
| if (arithmetic_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ArithmeticParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(arithmetic_param, 0, sizeof(ArithmeticParameter)); | |||||
| arithmetic_param->op_parameter_.type_ = primitive->Type(); | |||||
| arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting(); | |||||
| arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims(); | |||||
| arithmetic_param->activation_type_ = | |||||
| reinterpret_cast<mindspore::lite::Mul *>(const_cast<mindspore::lite::PrimitiveC *>(primitive))->GetActivationType(); | |||||
| auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0(); | |||||
| memcpy(arithmetic_param->in_shape0_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); | |||||
| tmp_shape = ((lite::Arithmetic *)primitive)->InShape1(); | |||||
| memcpy(arithmetic_param->in_shape1_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); | |||||
| tmp_shape = ((lite::Arithmetic *)primitive)->OutputShape(); | |||||
| memcpy(arithmetic_param->out_shape_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); | |||||
| return reinterpret_cast<OpParameter *>(arithmetic_param); | |||||
| } | |||||
| Registry MulParameterRegistry(schema::PrimitiveType_Mul, PopulateMulParameter); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -17,6 +17,9 @@ | |||||
| #include "src/ops/nchw2nhwc.h" | #include "src/ops/nchw2nhwc.h" | ||||
| #include "src/common/common.h" | #include "src/common/common.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/transpose.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -29,8 +32,29 @@ int Nchw2Nhwc::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffe | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *Nchw2NhwcCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<Nchw2Nhwc>(primitive); | |||||
| } | |||||
| Registry Nchw2NhwcRegistry(schema::PrimitiveType_Nchw2Nhwc, Nchw2NhwcCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateNchw2NhwcParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| TransposeParameter *parameter = reinterpret_cast<TransposeParameter *>(malloc(sizeof(TransposeParameter))); | |||||
| if (parameter == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc OpParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(parameter, 0, sizeof(OpParameter)); | |||||
| parameter->op_parameter_.type_ = primitive->Type(); | |||||
| parameter->num_axes_ = 4; | |||||
| parameter->perm_[0] = 0; | |||||
| parameter->perm_[1] = 2; | |||||
| parameter->perm_[2] = 3; | |||||
| parameter->perm_[3] = 1; | |||||
| return reinterpret_cast<OpParameter *>(parameter); | |||||
| } | |||||
| Registry Nchw2NhwcParameterRegistry(schema::PrimitiveType_Nchw2Nhwc, PopulateNchw2NhwcParameter); | |||||
| int Nchw2Nhwc::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) { | int Nchw2Nhwc::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) { | ||||
| MS_ASSERT(this->primitive_ != nullptr); | MS_ASSERT(this->primitive_ != nullptr); | ||||
| auto input = inputs_.front(); | auto input = inputs_.front(); | ||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/ops/neg.h" | #include "src/ops/neg.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -52,6 +54,9 @@ int Neg::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *NegCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Neg>(primitive); } | |||||
| Registry NegRegistry(schema::PrimitiveType_Neg, NegCreator); | |||||
| #endif | #endif | ||||
| Registry NegParameterRegistry(schema::PrimitiveType_Neg, PopulateArithmeticSelf); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/ops/neg_grad.h" | #include "src/ops/neg_grad.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifndef PRIMITIVE_WRITEABLE | #ifndef PRIMITIVE_WRITEABLE | ||||
| @@ -28,6 +30,11 @@ int NegGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *NegGradCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<NegGrad>(primitive); } | |||||
| Registry NegGradRegistry(schema::PrimitiveType_NegGrad, NegGradCreator); | |||||
| #endif | #endif | ||||
| Registry NegGradParameterRegistry(schema::PrimitiveType_NegGrad, PopulateArithmeticSelf); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -17,6 +17,9 @@ | |||||
| #include "src/ops/nhwc2nchw.h" | #include "src/ops/nhwc2nchw.h" | ||||
| #include "src/common/common.h" | #include "src/common/common.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/transpose.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -30,8 +33,30 @@ int Nhwc2Nchw::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffe | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *Nhwc2NchwCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<Nhwc2Nchw>(primitive); | |||||
| } | |||||
| Registry Nhwc2NchwRegistry(schema::PrimitiveType_Nhwc2Nchw, Nhwc2NchwCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateNhwc2NchwParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| TransposeParameter *parameter = reinterpret_cast<TransposeParameter *>(malloc(sizeof(TransposeParameter))); | |||||
| if (parameter == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc OpParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(parameter, 0, sizeof(OpParameter)); | |||||
| parameter->op_parameter_.type_ = primitive->Type(); | |||||
| parameter->num_axes_ = 4; | |||||
| parameter->perm_[0] = 0; | |||||
| parameter->perm_[1] = 3; | |||||
| parameter->perm_[2] = 1; | |||||
| parameter->perm_[3] = 2; | |||||
| return reinterpret_cast<OpParameter *>(parameter); | |||||
| } | |||||
| Registry Nhwc2NchwParameterRegistry(schema::PrimitiveType_Nhwc2Nchw, PopulateNhwc2NchwParameter); | |||||
| int Nhwc2Nchw::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) { | int Nhwc2Nchw::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) { | ||||
| MS_ASSERT(this->primitive_ != nullptr); | MS_ASSERT(this->primitive_ != nullptr); | ||||
| auto input = inputs_.front(); | auto input = inputs_.front(); | ||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/ops/not_equal.h" | #include "src/ops/not_equal.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -28,6 +30,11 @@ int NotEqual::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffer | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *NotEqualCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<NotEqual>(primitive); | |||||
| } | |||||
| Registry NotEqualRegistry(schema::PrimitiveType_NotEqual, NotEqualCreator); | |||||
| #endif | #endif | ||||
| int NotEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | int NotEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | ||||
| auto input = inputs_.front(); | auto input = inputs_.front(); | ||||
| @@ -39,5 +46,6 @@ int NotEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> ou | |||||
| output->SetFormat(input->GetFormat()); | output->SetFormat(input->GetFormat()); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| Registry NotEqualParameterRegistry(schema::PrimitiveType_NotEqual, PopulateArithmetic); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,9 @@ | |||||
| #include "src/ops/one_hot.h" | #include "src/ops/one_hot.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/fp32/one_hot.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -71,8 +74,30 @@ int OneHot::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers: | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *OneHotCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<OneHot>(primitive); } | |||||
| Registry OneHotRegistry(schema::PrimitiveType_OneHot, OneHotCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulateOneHotParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| OneHotParameter *one_hot_param = reinterpret_cast<OneHotParameter *>(malloc(sizeof(OneHotParameter))); | |||||
| if (one_hot_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc OneHotParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(one_hot_param, 0, sizeof(OneHotParameter)); | |||||
| one_hot_param->op_parameter_.type_ = primitive->Type(); | |||||
| auto param = reinterpret_cast<mindspore::lite::OneHot *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| if (param == nullptr) { | |||||
| free(one_hot_param); | |||||
| MS_LOG(ERROR) << "get OneHot param nullptr."; | |||||
| return nullptr; | |||||
| } | |||||
| one_hot_param->axis_ = param->GetAxis(); | |||||
| return reinterpret_cast<OpParameter *>(one_hot_param); | |||||
| } | |||||
| Registry OneHotParameterRegistry(schema::PrimitiveType_OneHot, PopulateOneHotParameter); | |||||
| namespace { | namespace { | ||||
| constexpr size_t kOneHotInputNum = 4; | constexpr size_t kOneHotInputNum = 4; | ||||
| } | } | ||||
| @@ -0,0 +1,71 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_OP_REGISTER_H | |||||
| #define LITE_MINDSPORE_LITE_C_OPS_OP_REGISTER_H | |||||
| #include <map> | |||||
| #include "src/ops/primitive_c.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class OpsRegistry { | |||||
| public: | |||||
| static OpsRegistry *GetInstance() { | |||||
| static OpsRegistry registry; | |||||
| return ®istry; | |||||
| } | |||||
| void insertPrimitiveCMap(schema::PrimitiveType type, PrimitiveCCreator creator) { | |||||
| primitive_creators[type] = creator; | |||||
| } | |||||
| PrimitiveCCreator getPrimitiveCreator(schema::PrimitiveType type) { | |||||
| if (primitive_creators.find(type) != primitive_creators.end()) { | |||||
| return primitive_creators[type]; | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Unsupported primitive type in Create : " << schema::EnumNamePrimitiveType(type); | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
| void insertParameterMap(schema::PrimitiveType type, ParameterCreator creator) { parameter_creators[type] = creator; } | |||||
| ParameterCreator getParameterCreator(schema::PrimitiveType type) { | |||||
| if (parameter_creators.find(type) != parameter_creators.end()) { | |||||
| return parameter_creators[type]; | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Unsupported parameter type in Create : " << schema::EnumNamePrimitiveType(type); | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
| protected: | |||||
| std::map<schema::PrimitiveType, PrimitiveCCreator> primitive_creators; | |||||
| std::map<schema::PrimitiveType, ParameterCreator> parameter_creators; | |||||
| }; | |||||
| class Registry { | |||||
| public: | |||||
| Registry(schema::PrimitiveType primitive_type, PrimitiveCCreator creator) { | |||||
| OpsRegistry::GetInstance()->insertPrimitiveCMap(primitive_type, creator); | |||||
| } | |||||
| Registry(schema::PrimitiveType primitive_type, ParameterCreator creator) { | |||||
| OpsRegistry::GetInstance()->insertParameterMap(primitive_type, creator); | |||||
| } | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // LITE_MINDSPORE_LITE_C_OPS_OP_REGISTER_H | |||||
| @@ -16,6 +16,9 @@ | |||||
| #include "src/ops/p_relu.h" | #include "src/ops/p_relu.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/prelu_parameter.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -46,6 +49,24 @@ int PReLU::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:: | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *PReLUCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<PReLU>(primitive); } | |||||
| Registry PReLURegistry(schema::PrimitiveType_PReLU, PReLUCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulatePReLUParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| auto param = reinterpret_cast<mindspore::lite::PReLU *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| PReluParameter *prelu_param = reinterpret_cast<PReluParameter *>(malloc(sizeof(PReluParameter))); | |||||
| if (prelu_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc PReluParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(prelu_param, 0, sizeof(PReluParameter)); | |||||
| prelu_param->op_parameter_.type_ = primitive->Type(); | |||||
| prelu_param->channelShared = param->GetChannelShared(); | |||||
| return reinterpret_cast<OpParameter *>(prelu_param); | |||||
| } | |||||
| Registry PReLUParameterRegistry(schema::PrimitiveType_PReLU, PopulatePReLUParameter); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,9 @@ | |||||
| #include "src/ops/pad.h" | #include "src/ops/pad.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/pad_parameter.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -57,7 +60,42 @@ int Pad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *PadCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Pad>(primitive); } | |||||
| Registry PadRegistry(schema::PrimitiveType_Pad, PadCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulatePadParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| PadParameter *pad_param = reinterpret_cast<PadParameter *>(malloc(sizeof(PadParameter))); | |||||
| if (pad_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc PadParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(pad_param, 0, sizeof(PadParameter)); | |||||
| pad_param->op_parameter_.type_ = primitive->Type(); | |||||
| auto pad_node = reinterpret_cast<mindspore::lite::Pad *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| pad_param->pad_mode_ = pad_node->GetPaddingMode(); | |||||
| if (pad_param->pad_mode_ == static_cast<int>(schema::PaddingMode_CONSTANT)) { | |||||
| pad_param->constant_value_ = pad_node->GetConstantValue(); | |||||
| auto size = pad_node->GetPaddings().size(); | |||||
| if (size > MAX_PAD_SIZE) { | |||||
| MS_LOG(ERROR) << "Invalid padding size: " << size; | |||||
| free(pad_param); | |||||
| return nullptr; | |||||
| } | |||||
| for (size_t i = 0; i < MAX_PAD_SIZE - size; ++i) { | |||||
| pad_param->paddings_[i] = 0; | |||||
| } | |||||
| for (size_t i = 0; i < size; i++) { | |||||
| pad_param->paddings_[MAX_PAD_SIZE - size + i] = pad_node->GetPaddings()[i]; | |||||
| } | |||||
| pad_param->padding_length = MAX_PAD_SIZE; | |||||
| } | |||||
| return reinterpret_cast<OpParameter *>(pad_param); | |||||
| } | |||||
| Registry PadParameterRegistry(schema::PrimitiveType_Pad, PopulatePadParameter); | |||||
| int Pad::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) { | int Pad::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) { | ||||
| MS_ASSERT(this->primitive_ != nullptr); | MS_ASSERT(this->primitive_ != nullptr); | ||||
| if (this->primitive_ == nullptr) { | if (this->primitive_ == nullptr) { | ||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/ops/permute.h" | #include "src/ops/permute.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -50,6 +52,9 @@ int Permute::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *PermuteCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Permute>(primitive); } | |||||
| Registry PermuteRegistry(schema::PrimitiveType_Permute, PermuteCreator); | |||||
| #endif | #endif | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,6 +19,9 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/pooling_parameter.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -158,8 +161,73 @@ int Pooling::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *PoolingCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Pooling>(primitive); } | |||||
| Registry PoolingRegistry(schema::PrimitiveType_Pooling, PoolingCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulatePoolingParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| auto pooling_primitive = | |||||
| reinterpret_cast<mindspore::lite::Pooling *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| PoolingParameter *pooling_param = reinterpret_cast<PoolingParameter *>(malloc(sizeof(PoolingParameter))); | |||||
| if (pooling_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc PoolingParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(pooling_param, 0, sizeof(PoolingParameter)); | |||||
| pooling_param->op_parameter_.type_ = primitive->Type(); | |||||
| pooling_param->global_ = pooling_primitive->GetGlobal(); | |||||
| pooling_param->window_w_ = pooling_primitive->GetWindowW(); | |||||
| pooling_param->window_h_ = pooling_primitive->GetWindowH(); | |||||
| auto pooling_lite_primitive = (lite::Pooling *)primitive; | |||||
| pooling_param->pad_u_ = pooling_lite_primitive->PadUp(); | |||||
| pooling_param->pad_d_ = pooling_lite_primitive->PadDown(); | |||||
| pooling_param->pad_l_ = pooling_lite_primitive->PadLeft(); | |||||
| pooling_param->pad_r_ = pooling_lite_primitive->PadRight(); | |||||
| pooling_param->stride_w_ = pooling_primitive->GetStrideW(); | |||||
| pooling_param->stride_h_ = pooling_primitive->GetStrideH(); | |||||
| pooling_param->avg_mode_ = pooling_primitive->GetAvgMode(); | |||||
| auto is_global = pooling_primitive->GetGlobal(); | |||||
| pooling_param->global_ = is_global; | |||||
| auto pool_mode = pooling_primitive->GetPoolingMode(); | |||||
| switch (pool_mode) { | |||||
| case schema::PoolMode_MAX_POOLING: | |||||
| pooling_param->pool_mode_ = PoolMode_MaxPool; | |||||
| break; | |||||
| case schema::PoolMode_MEAN_POOLING: | |||||
| pooling_param->pool_mode_ = PoolMode_AvgPool; | |||||
| break; | |||||
| default: | |||||
| pooling_param->pool_mode_ = PoolMode_No; | |||||
| break; | |||||
| } | |||||
| auto round_mode = pooling_primitive->GetRoundMode(); | |||||
| switch (round_mode) { | |||||
| case schema::RoundMode_FLOOR: | |||||
| pooling_param->round_mode_ = RoundMode_Floor; | |||||
| break; | |||||
| case schema::RoundMode_CEIL: | |||||
| pooling_param->round_mode_ = RoundMode_Ceil; | |||||
| break; | |||||
| default: | |||||
| pooling_param->round_mode_ = RoundMode_No; | |||||
| break; | |||||
| } | |||||
| if (pooling_primitive->GetActivationType() == schema::ActivationType_RELU) { | |||||
| pooling_param->act_type_ = ActType_Relu; | |||||
| } else if (pooling_primitive->GetActivationType() == schema::ActivationType_RELU6) { | |||||
| pooling_param->act_type_ = ActType_Relu6; | |||||
| } else { | |||||
| pooling_param->act_type_ = ActType_No; | |||||
| } | |||||
| return reinterpret_cast<OpParameter *>(pooling_param); | |||||
| } | |||||
| Registry PoolingParameterRegistry(schema::PrimitiveType_Pooling, PopulatePoolingParameter); | |||||
| int Pooling::PadUp() const { return this->pad_u_; } | int Pooling::PadUp() const { return this->pad_u_; } | ||||
| int Pooling::PadDown() const { return this->pad_d_; } | int Pooling::PadDown() const { return this->pad_d_; } | ||||
| int Pooling::PadLeft() const { return this->pad_l_; } | int Pooling::PadLeft() const { return this->pad_l_; } | ||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/ops/pooling_grad.h" | #include "src/ops/pooling_grad.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -142,6 +144,11 @@ int PoolingGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuf | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *PoolingGradCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<PoolingGrad>(primitive); | |||||
| } | |||||
| Registry PoolingGradRegistry(schema::PrimitiveType_PoolingGrad, PoolingGradCreator); | |||||
| #endif | #endif | ||||
| int PoolingGrad::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | int PoolingGrad::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | ||||
| @@ -16,6 +16,9 @@ | |||||
| #include "src/ops/power.h" | #include "src/ops/power.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| #include "nnacl/power_parameter.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -45,8 +48,28 @@ int Power::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:: | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *PowerCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Power>(primitive); } | |||||
| Registry PowerRegistry(schema::PrimitiveType_Power, PowerCreator); | |||||
| #endif | #endif | ||||
| OpParameter *PopulatePowerParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| PowerParameter *power_param = reinterpret_cast<PowerParameter *>(malloc(sizeof(PowerParameter))); | |||||
| if (power_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc PowerParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(power_param, 0, sizeof(PowerParameter)); | |||||
| power_param->op_parameter_.type_ = primitive->Type(); | |||||
| auto power = reinterpret_cast<mindspore::lite::Power *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| power_param->power_ = power->GetPower(); | |||||
| power_param->scale_ = power->GetScale(); | |||||
| power_param->shift_ = power->GetShift(); | |||||
| return reinterpret_cast<OpParameter *>(power_param); | |||||
| } | |||||
| Registry PowerParameterRegistry(schema::PrimitiveType_Power, PopulatePowerParameter); | |||||
| int Power::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) { | int Power::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) { | ||||
| MS_ASSERT(this->primitive_ != nullptr); | MS_ASSERT(this->primitive_ != nullptr); | ||||
| auto x_tensor = inputs[0]; | auto x_tensor = inputs[0]; | ||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/ops/power_grad.h" | #include "src/ops/power_grad.h" | ||||
| #include "src/ops/ops_register.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -77,6 +79,11 @@ int PowerGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffe | |||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *PowerGradCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<PowerGrad>(primitive); | |||||
| } | |||||
| Registry PowerGradRegistry(schema::PrimitiveType_PowerGrad, PowerGradCreator); | |||||
| #endif | #endif | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||