| @@ -651,184 +651,184 @@ schema::PrimitiveT *ZerosLikePrimitiveCreator(const AnfNodePtr &node) { | |||
| return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | |||
| } | |||
| RegistryMSOps g_AbsPrimitiveCreatorRegistry("Abs", AbsPrimitiveCreator); | |||
| RegistryMSOps g_ActivationPrimitiveCreatorRegistry("Activation", ActivationPrimitiveCreator); | |||
| RegistryMSOps g_ActivationGradPrimitiveCreatorRegistry("ActivationGrad", ActivationGradPrimitiveCreator); | |||
| RegistryMSOps g_AddPrimitiveCreatorRegistry("Add", AddFusionPrimitiveCreator); | |||
| RegistryMSOps g_AddFusionPrimitiveCreatorRegistry("AddFusion", AddFusionPrimitiveCreator); | |||
| RegistryMSOps g_AddGradPrimitiveCreatorRegistry("AddGrad", AddGradPrimitiveCreator); | |||
| RegistryMSOps g_AdderPrimitiveCreatorRegistry("Adder", AdderFusionPrimitiveCreator); | |||
| RegistryMSOps g_AdderFusionPrimitiveCreatorRegistry("AdderFusion", AdderFusionPrimitiveCreator); | |||
| RegistryMSOps g_AddNPrimitiveCreatorRegistry("AddN", AddNPrimitiveCreator); | |||
| RegistryMSOps g_AllPrimitiveCreatorRegistry("All", AllPrimitiveCreator); | |||
| RegistryMSOps g_ApplyMomentumPrimitiveCreatorRegistry("ApplyMomentum", ApplyMomentumPrimitiveCreator); | |||
| RegistryMSOps g_ArgMaxPrimitiveCreatorRegistry("ArgMax", ArgMaxFusionPrimitiveCreator); | |||
| RegistryMSOps g_ArgMaxFusionPrimitiveCreatorRegistry("ArgMaxFusion", ArgMaxFusionPrimitiveCreator); | |||
| RegistryMSOps g_ArgMinPrimitiveCreatorRegistry("ArgMin", ArgMinFusionPrimitiveCreator); | |||
| RegistryMSOps g_ArgMinFusionPrimitiveCreatorRegistry("ArgMinFusion", ArgMinFusionPrimitiveCreator); | |||
| RegistryMSOps g_AssertPrimitiveCreatorRegistry("Assert", AssertPrimitiveCreator); | |||
| RegistryMSOps g_AssignPrimitiveCreatorRegistry("Assign", AssignPrimitiveCreator); | |||
| RegistryMSOps g_AssignAddPrimitiveCreatorRegistry("AssignAdd", AssignAddPrimitiveCreator); | |||
| RegistryMSOps g_AvgPoolPrimitiveCreatorRegistry("AvgPool", AvgPoolFusionPrimitiveCreator); | |||
| RegistryMSOps g_AvgPoolFusionPrimitiveCreatorRegistry("AvgPoolFusion", AvgPoolFusionPrimitiveCreator); | |||
| RegistryMSOps g_BatchNormPrimitiveCreatorRegistry("BatchNorm", BatchNormPrimitiveCreator); | |||
| RegistryMSOps g_BatchToSpacePrimitiveCreatorRegistry("BatchToSpace", BatchToSpacePrimitiveCreator); | |||
| RegistryMSOps g_BatchToSpaceNDPrimitiveCreatorRegistry("BatchToSpaceND", BatchToSpaceNDPrimitiveCreator); | |||
| RegistryMSOps g_BiasAddPrimitiveCreatorRegistry("BiasAdd", BiasAddPrimitiveCreator); | |||
| RegistryMSOps g_BNGradPrimitiveCreatorRegistry("BNGrad", BNGradPrimitiveCreator); | |||
| RegistryMSOps g_BroadcastToPrimitiveCreatorRegistry("BroadcastTo", BroadcastToPrimitiveCreator); | |||
| RegistryMSOps g_CastPrimitiveCreatorRegistry("Cast", CastPrimitiveCreator); | |||
| RegistryMSOps g_CeilPrimitiveCreatorRegistry("Ceil", CeilPrimitiveCreator); | |||
| RegistryMSOps g_ClipPrimitiveCreatorRegistry("Clip", ClipPrimitiveCreator); | |||
| RegistryMSOps g_ConcatPrimitiveCreatorRegistry("Concat", ConcatPrimitiveCreator); | |||
| // RegistryMSOps g_ControlDependPrimitiveCreatorRegistry("ControlDepend", ControlDependPrimitiveCreator); | |||
| RegistryMSOps g_Conv2DBackpropFilterFusionPrimitiveCreatorRegistry("Conv2DBackpropFilterFusion", | |||
| RegistryMSOps g_absPrimitiveCreatorRegistry("Abs", AbsPrimitiveCreator); | |||
| RegistryMSOps g_activationPrimitiveCreatorRegistry("Activation", ActivationPrimitiveCreator); | |||
| RegistryMSOps g_activationGradPrimitiveCreatorRegistry("ActivationGrad", ActivationGradPrimitiveCreator); | |||
| RegistryMSOps g_addPrimitiveCreatorRegistry("Add", AddFusionPrimitiveCreator); | |||
| RegistryMSOps g_addFusionPrimitiveCreatorRegistry("AddFusion", AddFusionPrimitiveCreator); | |||
| RegistryMSOps g_addGradPrimitiveCreatorRegistry("AddGrad", AddGradPrimitiveCreator); | |||
| RegistryMSOps g_adderPrimitiveCreatorRegistry("Adder", AdderFusionPrimitiveCreator); | |||
| RegistryMSOps g_adderFusionPrimitiveCreatorRegistry("AdderFusion", AdderFusionPrimitiveCreator); | |||
| RegistryMSOps g_addNPrimitiveCreatorRegistry("AddN", AddNPrimitiveCreator); | |||
| RegistryMSOps g_allPrimitiveCreatorRegistry("All", AllPrimitiveCreator); | |||
| RegistryMSOps g_applyMomentumPrimitiveCreatorRegistry("ApplyMomentum", ApplyMomentumPrimitiveCreator); | |||
| RegistryMSOps g_argMaxPrimitiveCreatorRegistry("ArgMax", ArgMaxFusionPrimitiveCreator); | |||
| RegistryMSOps g_argMaxFusionPrimitiveCreatorRegistry("ArgMaxFusion", ArgMaxFusionPrimitiveCreator); | |||
| RegistryMSOps g_argMinPrimitiveCreatorRegistry("ArgMin", ArgMinFusionPrimitiveCreator); | |||
| RegistryMSOps g_argMinFusionPrimitiveCreatorRegistry("ArgMinFusion", ArgMinFusionPrimitiveCreator); | |||
| RegistryMSOps g_assertPrimitiveCreatorRegistry("Assert", AssertPrimitiveCreator); | |||
| RegistryMSOps g_assignPrimitiveCreatorRegistry("Assign", AssignPrimitiveCreator); | |||
| RegistryMSOps g_assignAddPrimitiveCreatorRegistry("AssignAdd", AssignAddPrimitiveCreator); | |||
| RegistryMSOps g_avgPoolPrimitiveCreatorRegistry("AvgPool", AvgPoolFusionPrimitiveCreator); | |||
| RegistryMSOps g_avgPoolFusionPrimitiveCreatorRegistry("AvgPoolFusion", AvgPoolFusionPrimitiveCreator); | |||
| RegistryMSOps g_batchNormPrimitiveCreatorRegistry("BatchNorm", BatchNormPrimitiveCreator); | |||
| RegistryMSOps g_batchToSpacePrimitiveCreatorRegistry("BatchToSpace", BatchToSpacePrimitiveCreator); | |||
| RegistryMSOps g_batchToSpaceNDPrimitiveCreatorRegistry("BatchToSpaceND", BatchToSpaceNDPrimitiveCreator); | |||
| RegistryMSOps g_biasAddPrimitiveCreatorRegistry("BiasAdd", BiasAddPrimitiveCreator); | |||
| RegistryMSOps g_bNGradPrimitiveCreatorRegistry("BNGrad", BNGradPrimitiveCreator); | |||
| RegistryMSOps g_broadcastToPrimitiveCreatorRegistry("BroadcastTo", BroadcastToPrimitiveCreator); | |||
| RegistryMSOps g_castPrimitiveCreatorRegistry("Cast", CastPrimitiveCreator); | |||
| RegistryMSOps g_ceilPrimitiveCreatorRegistry("Ceil", CeilPrimitiveCreator); | |||
| RegistryMSOps g_clipPrimitiveCreatorRegistry("Clip", ClipPrimitiveCreator); | |||
| RegistryMSOps g_concatPrimitiveCreatorRegistry("Concat", ConcatPrimitiveCreator); | |||
| // RegistryMSOps g_controlDependPrimitiveCreatorRegistry("ControlDepend", ControlDependPrimitiveCreator); | |||
| RegistryMSOps g_conv2DBackpropFilterFusionPrimitiveCreatorRegistry("Conv2DBackpropFilterFusion", | |||
| Conv2DBackpropFilterFusionPrimitiveCreator); | |||
| RegistryMSOps g_Conv2DBackpropInputFusionPrimitiveCreatorRegistry("Conv2DBackpropInputFusion", | |||
| RegistryMSOps g_conv2DBackpropInputFusionPrimitiveCreatorRegistry("Conv2DBackpropInputFusion", | |||
| Conv2DBackpropInputFusionPrimitiveCreator); | |||
| RegistryMSOps g_Conv2DPrimitiveCreatorRegistry("Conv2D", Conv2DFusionPrimitiveCreator); | |||
| RegistryMSOps g_Conv2DFusionPrimitiveCreatorRegistry("Conv2DFusion", Conv2DFusionPrimitiveCreator); | |||
| RegistryMSOps g_Conv2dTransposePrimitiveCreatorRegistry("Conv2dTranspose", Conv2dTransposeFusionPrimitiveCreator); | |||
| RegistryMSOps g_Conv2dTransposeFusionPrimitiveCreatorRegistry("Conv2dTransposeFusion", | |||
| RegistryMSOps g_conv2DPrimitiveCreatorRegistry("Conv2D", Conv2DFusionPrimitiveCreator); | |||
| RegistryMSOps g_conv2DFusionPrimitiveCreatorRegistry("Conv2DFusion", Conv2DFusionPrimitiveCreator); | |||
| RegistryMSOps g_conv2dTransposePrimitiveCreatorRegistry("Conv2dTranspose", Conv2dTransposeFusionPrimitiveCreator); | |||
| RegistryMSOps g_conv2dTransposeFusionPrimitiveCreatorRegistry("Conv2dTransposeFusion", | |||
| Conv2dTransposeFusionPrimitiveCreator); | |||
| RegistryMSOps g_ConstantOfShapePrimitiveCreatorRegistry("ConstantOfShape", ConstantOfShapePrimitiveCreator); | |||
| RegistryMSOps g_CosPrimitiveCreatorRegistry("Cos", CosPrimitiveCreator); | |||
| RegistryMSOps g_CropPrimitiveCreatorRegistry("Crop", CropPrimitiveCreator); | |||
| RegistryMSOps g_CustomExtractFeaturesPrimitiveCreatorRegistry("CustomExtractFeatures", | |||
| RegistryMSOps g_constantOfShapePrimitiveCreatorRegistry("ConstantOfShape", ConstantOfShapePrimitiveCreator); | |||
| RegistryMSOps g_cosPrimitiveCreatorRegistry("Cos", CosPrimitiveCreator); | |||
| RegistryMSOps g_cropPrimitiveCreatorRegistry("Crop", CropPrimitiveCreator); | |||
| RegistryMSOps g_customExtractFeaturesPrimitiveCreatorRegistry("CustomExtractFeatures", | |||
| CustomExtractFeaturesPrimitiveCreator); | |||
| RegistryMSOps g_CustomNormalizePrimitiveCreatorRegistry("CustomNormalize", CustomNormalizePrimitiveCreator); | |||
| RegistryMSOps g_CustomPredictPrimitiveCreatorRegistry("CustomPredict", CustomPredictPrimitiveCreator); | |||
| RegistryMSOps g_DependPrimitiveCreatorRegistry("Depend", DependPrimitiveCreator); | |||
| RegistryMSOps g_DepthToSpacePrimitiveCreatorRegistry("DepthToSpace", DepthToSpacePrimitiveCreator); | |||
| RegistryMSOps g_DetectionPostProcessPrimitiveCreatorRegistry("DetectionPostProcess", | |||
| RegistryMSOps g_customNormalizePrimitiveCreatorRegistry("CustomNormalize", CustomNormalizePrimitiveCreator); | |||
| RegistryMSOps g_customPredictPrimitiveCreatorRegistry("CustomPredict", CustomPredictPrimitiveCreator); | |||
| RegistryMSOps g_dependPrimitiveCreatorRegistry("Depend", DependPrimitiveCreator); | |||
| RegistryMSOps g_depthToSpacePrimitiveCreatorRegistry("DepthToSpace", DepthToSpacePrimitiveCreator); | |||
| RegistryMSOps g_detectionPostProcessPrimitiveCreatorRegistry("DetectionPostProcess", | |||
| DetectionPostProcessPrimitiveCreator); | |||
| RegistryMSOps g_DivPrimitiveCreatorRegistry("Div", DivFusionPrimitiveCreator); | |||
| RegistryMSOps g_DivFusionPrimitiveCreatorRegistry("DivFusion", DivFusionPrimitiveCreator); | |||
| RegistryMSOps g_DivGradPrimitiveCreatorRegistry("DivGrad", DivGradPrimitiveCreator); | |||
| RegistryMSOps g_DropoutPrimitiveCreatorRegistry("Dropout", DropoutPrimitiveCreator); | |||
| RegistryMSOps g_DropoutGradPrimitiveCreatorRegistry("DropoutGrad", DropoutGradPrimitiveCreator); | |||
| RegistryMSOps g_EltwisePrimitiveCreatorRegistry("Eltwise", EltwisePrimitiveCreator); | |||
| RegistryMSOps g_EluPrimitiveCreatorRegistry("Elu", EluPrimitiveCreator); | |||
| RegistryMSOps g_EqualPrimitiveCreatorRegistry("Equal", EqualPrimitiveCreator); | |||
| RegistryMSOps g_EmbeddingLookupFusionPrimitiveCreatorRegistry("EmbeddingLookupFusion", | |||
| RegistryMSOps g_divPrimitiveCreatorRegistry("Div", DivFusionPrimitiveCreator); | |||
| RegistryMSOps g_divFusionPrimitiveCreatorRegistry("DivFusion", DivFusionPrimitiveCreator); | |||
| RegistryMSOps g_divGradPrimitiveCreatorRegistry("DivGrad", DivGradPrimitiveCreator); | |||
| RegistryMSOps g_dropoutPrimitiveCreatorRegistry("Dropout", DropoutPrimitiveCreator); | |||
| RegistryMSOps g_dropoutGradPrimitiveCreatorRegistry("DropoutGrad", DropoutGradPrimitiveCreator); | |||
| RegistryMSOps g_eltwisePrimitiveCreatorRegistry("Eltwise", EltwisePrimitiveCreator); | |||
| RegistryMSOps g_eluPrimitiveCreatorRegistry("Elu", EluPrimitiveCreator); | |||
| RegistryMSOps g_equalPrimitiveCreatorRegistry("Equal", EqualPrimitiveCreator); | |||
| RegistryMSOps g_embeddingLookupFusionPrimitiveCreatorRegistry("EmbeddingLookupFusion", | |||
| EmbeddingLookupFusionPrimitiveCreator); | |||
| RegistryMSOps g_ExpandDimsPrimitiveCreatorRegistry("ExpandDims", ExpandDimsPrimitiveCreator); | |||
| RegistryMSOps g_ExpPrimitiveCreatorRegistry("Exp", ExpFusionPrimitiveCreator); | |||
| RegistryMSOps g_ExpFusionPrimitiveCreatorRegistry("ExpFusion", ExpFusionPrimitiveCreator); | |||
| RegistryMSOps g_FftImagPrimitiveCreatorRegistry("FftImag", FftImagPrimitiveCreator); | |||
| RegistryMSOps g_FftRealPrimitiveCreatorRegistry("FftReal", FftRealPrimitiveCreator); | |||
| RegistryMSOps g_FillPrimitiveCreatorRegistry("Fill", FillPrimitiveCreator); | |||
| RegistryMSOps g_FlattenPrimitiveCreatorRegistry("Flatten", FlattenPrimitiveCreator); | |||
| RegistryMSOps g_FlattenGradPrimitiveCreatorRegistry("FlattenGrad", FlattenGradPrimitiveCreator); | |||
| RegistryMSOps g_FloorPrimitiveCreatorRegistry("Floor", FloorPrimitiveCreator); | |||
| RegistryMSOps g_FloorDivPrimitiveCreatorRegistry("FloorDiv", FloorDivPrimitiveCreator); | |||
| RegistryMSOps g_FloorModPrimitiveCreatorRegistry("FloorMod", FloorModPrimitiveCreator); | |||
| RegistryMSOps g_FullConnectionPrimitiveCreatorRegistry("FullConnection", FullConnectionPrimitiveCreator); | |||
| RegistryMSOps g_FusedBatchNormPrimitiveCreatorRegistry("FusedBatchNorm", FusedBatchNormPrimitiveCreator); | |||
| RegistryMSOps g_GatherPrimitiveCreatorRegistry("Gather", GatherPrimitiveCreator); | |||
| RegistryMSOps g_GatherNdPrimitiveCreatorRegistry("GatherNd", GatherNdPrimitiveCreator); | |||
| RegistryMSOps g_GreaterPrimitiveCreatorRegistry("Greater", GreaterPrimitiveCreator); | |||
| RegistryMSOps g_GreaterEqualPrimitiveCreatorRegistry("GreaterEqual", GreaterEqualPrimitiveCreator); | |||
| RegistryMSOps g_HashtableLookupPrimitiveCreatorRegistry("HashtableLookup", HashtableLookupPrimitiveCreator); | |||
| RegistryMSOps g_InstanceNormPrimitiveCreatorRegistry("InstanceNorm", InstanceNormPrimitiveCreator); | |||
| RegistryMSOps g_LayerNormPrimitiveCreatorRegistry("LayerNorm", LayerNormFusionPrimitiveCreator); | |||
| RegistryMSOps g_LayerNormFusionPrimitiveCreatorRegistry("LayerNormFusion", LayerNormFusionPrimitiveCreator); | |||
| RegistryMSOps g_LeakyReluPrimitiveCreatorRegistry("LeakyRelu", LeakyReluPrimitiveCreator); | |||
| RegistryMSOps g_LessPrimitiveCreatorRegistry("Less", LessPrimitiveCreator); | |||
| RegistryMSOps g_LessEqualPrimitiveCreatorRegistry("LessEqual", LessEqualPrimitiveCreator); | |||
| RegistryMSOps g_LogPrimitiveCreatorRegistry("Log", LogPrimitiveCreator); | |||
| RegistryMSOps g_LogGradPrimitiveCreatorRegistry("LogGrad", LogGradPrimitiveCreator); | |||
| RegistryMSOps g_LogicalAndPrimitiveCreatorRegistry("LogicalAnd", LogicalAndPrimitiveCreator); | |||
| RegistryMSOps g_LogicalNotPrimitiveCreatorRegistry("LogicalNot", LogicalNotPrimitiveCreator); | |||
| RegistryMSOps g_LogicalOrPrimitiveCreatorRegistry("LogicalOr", LogicalOrPrimitiveCreator); | |||
| RegistryMSOps g_LpNormalizationPrimitiveCreatorRegistry("LpNormalization", LpNormalizationPrimitiveCreator); | |||
| RegistryMSOps g_LrnPrimitiveCreatorRegistry("Lrn", LrnPrimitiveCreator); | |||
| RegistryMSOps g_LshProjectionPrimitiveCreatorRegistry("LshProjection", LshProjectionPrimitiveCreator); | |||
| RegistryMSOps g_LSTMPrimitiveCreatorRegistry("LSTM", LSTMPrimitiveCreator); | |||
| RegistryMSOps g_L2NormalizeFusionPrimitiveCreatorRegistry("L2NormalizeFusion", L2NormalizeFusionPrimitiveCreator); | |||
| RegistryMSOps g_MatMulPrimitiveCreatorRegistry("MatMul", MatMulPrimitiveCreator); | |||
| RegistryMSOps g_MaximumPrimitiveCreatorRegistry("Maximum", MaximumPrimitiveCreator); | |||
| RegistryMSOps g_MaximumGradPrimitiveCreatorRegistry("MaximumGrad", MaximumGradPrimitiveCreator); | |||
| RegistryMSOps g_MaxPoolPrimitiveCreatorRegistry("MaxPool", MaxPoolFusionPrimitiveCreator); | |||
| RegistryMSOps g_MaxPoolFusionPrimitiveCreatorRegistry("MaxPoolFusion", MaxPoolFusionPrimitiveCreator); | |||
| RegistryMSOps g_MergePrimitiveCreatorRegistry("Merge", MergePrimitiveCreator); | |||
| RegistryMSOps g_MinimumPrimitiveCreatorRegistry("Minimum", MinimumPrimitiveCreator); | |||
| RegistryMSOps g_MinimumGradPrimitiveCreatorRegistry("MinimumGrad", MinimumGradPrimitiveCreator); | |||
| RegistryMSOps g_ModPrimitiveCreatorRegistry("Mod", ModPrimitiveCreator); | |||
| RegistryMSOps g_MulPrimitiveCreatorRegistry("Mul", MulFusionPrimitiveCreator); | |||
| RegistryMSOps g_MulMulFusionPrimitiveCreatorRegistry("MulFusion", MulFusionPrimitiveCreator); | |||
| RegistryMSOps g_MulGradPrimitiveCreatorRegistry("MulGrad", MulGradPrimitiveCreator); | |||
| RegistryMSOps g_NegPrimitiveCreatorRegistry("Neg", NegPrimitiveCreator); | |||
| RegistryMSOps g_NegGradPrimitiveCreatorRegistry("NegGrad", NegGradPrimitiveCreator); | |||
| RegistryMSOps g_NonMaxSuppressionPrimitiveCreatorRegistry("NonMaxSuppression", NonMaxSuppressionPrimitiveCreator); | |||
| RegistryMSOps g_NotEqualPrimitiveCreatorRegistry("NotEqual", NotEqualPrimitiveCreator); | |||
| RegistryMSOps g_OneHotPrimitiveCreatorRegistry("OneHot", OneHotPrimitiveCreator); | |||
| RegistryMSOps g_OnesLikePrimitiveCreatorRegistry("OnesLike", OnesLikePrimitiveCreator); | |||
| RegistryMSOps g_PadPrimitiveCreatorRegistry("Pad", PadFusionPrimitiveCreator); | |||
| RegistryMSOps g_PadFusionPrimitiveCreatorRegistry("PadFusion", PadFusionPrimitiveCreator); | |||
| RegistryMSOps g_PartialFusionPrimitiveCreatorRegistry("PartialFusion", PartialFusionPrimitiveCreator); | |||
| RegistryMSOps g_PowerGradPrimitiveCreatorRegistry("PowerGrad", PowerGradPrimitiveCreator); | |||
| RegistryMSOps g_PowFusionPrimitiveCreatorRegistry("PowFusion", PowFusionPrimitiveCreator); | |||
| RegistryMSOps g_PReLUFusionPrimitiveCreatorRegistry("PReLUFusion", PReLUFusionPrimitiveCreator); | |||
| RegistryMSOps g_RangePrimitiveCreatorRegistry("Range", RangePrimitiveCreator); | |||
| RegistryMSOps g_RankPrimitiveCreatorRegistry("Rank", RankPrimitiveCreator); | |||
| RegistryMSOps g_ReciprocalPrimitiveCreatorRegistry("Reciprocal", ReciprocalPrimitiveCreator); | |||
| RegistryMSOps g_RealDivPrimitiveCreatorRegistry("RealDiv", RealDivPrimitiveCreator); | |||
| RegistryMSOps g_ReducePrimitiveCreatorRegistry("Reduce", ReduceFusionPrimitiveCreator); | |||
| RegistryMSOps g_ReduceFusionPrimitiveCreatorRegistry("ReduceFusion", ReduceFusionPrimitiveCreator); | |||
| RegistryMSOps g_ReshapePrimitiveCreatorRegistry("Reshape", ReshapePrimitiveCreator); | |||
| RegistryMSOps g_ResizePrimitiveCreatorRegistry("Resize", ResizePrimitiveCreator); | |||
| RegistryMSOps g_ReverseV2PrimitiveCreatorRegistry("ReverseV2", ReverseV2PrimitiveCreator); | |||
| RegistryMSOps g_ReverseSequencePrimitiveCreatorRegistry("ReverseSequence", ReverseSequencePrimitiveCreator); | |||
| RegistryMSOps g_RfftPrimitiveCreatorRegistry("Rfft", RfftPrimitiveCreator); | |||
| RegistryMSOps g_ROIPoolingPrimitiveCreatorRegistry("ROIPooling", ROIPoolingPrimitiveCreator); | |||
| RegistryMSOps g_RoundPrimitiveCreatorRegistry("Round", RoundPrimitiveCreator); | |||
| RegistryMSOps g_RsqrtPrimitiveCreatorRegistry("Rsqrt", RsqrtPrimitiveCreator); | |||
| RegistryMSOps g_QuantDTypeCastPrimitiveCreatorRegistry("QuantDTypeCast", QuantDTypeCastPrimitiveCreator); | |||
| RegistryMSOps g_ScalePrimitiveCreatorRegistry("Scale", ScaleFusionPrimitiveCreator); | |||
| RegistryMSOps g_ScaleFusionPrimitiveCreatorRegistry("ScaleFusion", ScaleFusionPrimitiveCreator); | |||
| RegistryMSOps g_ShapePrimitiveCreatorRegistry("Shape", ShapePrimitiveCreator); | |||
| RegistryMSOps g_SigmoidCrossEntropyWithLogitsPrimitiveCreatorRegistry("SigmoidCrossEntropyWithLogits", | |||
| RegistryMSOps g_expandDimsPrimitiveCreatorRegistry("ExpandDims", ExpandDimsPrimitiveCreator); | |||
| RegistryMSOps g_expPrimitiveCreatorRegistry("Exp", ExpFusionPrimitiveCreator); | |||
| RegistryMSOps g_expFusionPrimitiveCreatorRegistry("ExpFusion", ExpFusionPrimitiveCreator); | |||
| RegistryMSOps g_fftImagPrimitiveCreatorRegistry("FftImag", FftImagPrimitiveCreator); | |||
| RegistryMSOps g_fftRealPrimitiveCreatorRegistry("FftReal", FftRealPrimitiveCreator); | |||
| RegistryMSOps g_fillPrimitiveCreatorRegistry("Fill", FillPrimitiveCreator); | |||
| RegistryMSOps g_flattenPrimitiveCreatorRegistry("Flatten", FlattenPrimitiveCreator); | |||
| RegistryMSOps g_flattenGradPrimitiveCreatorRegistry("FlattenGrad", FlattenGradPrimitiveCreator); | |||
| RegistryMSOps g_floorPrimitiveCreatorRegistry("Floor", FloorPrimitiveCreator); | |||
| RegistryMSOps g_floorDivPrimitiveCreatorRegistry("FloorDiv", FloorDivPrimitiveCreator); | |||
| RegistryMSOps g_floorModPrimitiveCreatorRegistry("FloorMod", FloorModPrimitiveCreator); | |||
| RegistryMSOps g_fullConnectionPrimitiveCreatorRegistry("FullConnection", FullConnectionPrimitiveCreator); | |||
| RegistryMSOps g_fusedBatchNormPrimitiveCreatorRegistry("FusedBatchNorm", FusedBatchNormPrimitiveCreator); | |||
| RegistryMSOps g_gatherPrimitiveCreatorRegistry("Gather", GatherPrimitiveCreator); | |||
| RegistryMSOps g_gatherNdPrimitiveCreatorRegistry("GatherNd", GatherNdPrimitiveCreator); | |||
| RegistryMSOps g_greaterPrimitiveCreatorRegistry("Greater", GreaterPrimitiveCreator); | |||
| RegistryMSOps g_greaterEqualPrimitiveCreatorRegistry("GreaterEqual", GreaterEqualPrimitiveCreator); | |||
| RegistryMSOps g_hashtableLookupPrimitiveCreatorRegistry("HashtableLookup", HashtableLookupPrimitiveCreator); | |||
| RegistryMSOps g_instanceNormPrimitiveCreatorRegistry("InstanceNorm", InstanceNormPrimitiveCreator); | |||
| RegistryMSOps g_layerNormPrimitiveCreatorRegistry("LayerNorm", LayerNormFusionPrimitiveCreator); | |||
| RegistryMSOps g_layerNormFusionPrimitiveCreatorRegistry("LayerNormFusion", LayerNormFusionPrimitiveCreator); | |||
| RegistryMSOps g_leakyReluPrimitiveCreatorRegistry("LeakyRelu", LeakyReluPrimitiveCreator); | |||
| RegistryMSOps g_lessPrimitiveCreatorRegistry("Less", LessPrimitiveCreator); | |||
| RegistryMSOps g_lessEqualPrimitiveCreatorRegistry("LessEqual", LessEqualPrimitiveCreator); | |||
| RegistryMSOps g_logPrimitiveCreatorRegistry("Log", LogPrimitiveCreator); | |||
| RegistryMSOps g_logGradPrimitiveCreatorRegistry("LogGrad", LogGradPrimitiveCreator); | |||
| RegistryMSOps g_logicalAndPrimitiveCreatorRegistry("LogicalAnd", LogicalAndPrimitiveCreator); | |||
| RegistryMSOps g_logicalNotPrimitiveCreatorRegistry("LogicalNot", LogicalNotPrimitiveCreator); | |||
| RegistryMSOps g_logicalOrPrimitiveCreatorRegistry("LogicalOr", LogicalOrPrimitiveCreator); | |||
| RegistryMSOps g_lpNormalizationPrimitiveCreatorRegistry("LpNormalization", LpNormalizationPrimitiveCreator); | |||
| RegistryMSOps g_lrnPrimitiveCreatorRegistry("Lrn", LrnPrimitiveCreator); | |||
| RegistryMSOps g_lshProjectionPrimitiveCreatorRegistry("LshProjection", LshProjectionPrimitiveCreator); | |||
| RegistryMSOps g_lSTMPrimitiveCreatorRegistry("LSTM", LSTMPrimitiveCreator); | |||
| RegistryMSOps g_l2NormalizeFusionPrimitiveCreatorRegistry("L2NormalizeFusion", L2NormalizeFusionPrimitiveCreator); | |||
| RegistryMSOps g_matMulPrimitiveCreatorRegistry("MatMul", MatMulPrimitiveCreator); | |||
| RegistryMSOps g_maximumPrimitiveCreatorRegistry("Maximum", MaximumPrimitiveCreator); | |||
| RegistryMSOps g_maximumGradPrimitiveCreatorRegistry("MaximumGrad", MaximumGradPrimitiveCreator); | |||
| RegistryMSOps g_maxPoolPrimitiveCreatorRegistry("MaxPool", MaxPoolFusionPrimitiveCreator); | |||
| RegistryMSOps g_maxPoolFusionPrimitiveCreatorRegistry("MaxPoolFusion", MaxPoolFusionPrimitiveCreator); | |||
| RegistryMSOps g_mergePrimitiveCreatorRegistry("Merge", MergePrimitiveCreator); | |||
| RegistryMSOps g_minimumPrimitiveCreatorRegistry("Minimum", MinimumPrimitiveCreator); | |||
| RegistryMSOps g_minimumGradPrimitiveCreatorRegistry("MinimumGrad", MinimumGradPrimitiveCreator); | |||
| RegistryMSOps g_modPrimitiveCreatorRegistry("Mod", ModPrimitiveCreator); | |||
| RegistryMSOps g_mulPrimitiveCreatorRegistry("Mul", MulFusionPrimitiveCreator); | |||
| RegistryMSOps g_mulMulFusionPrimitiveCreatorRegistry("MulFusion", MulFusionPrimitiveCreator); | |||
| RegistryMSOps g_mulGradPrimitiveCreatorRegistry("MulGrad", MulGradPrimitiveCreator); | |||
| RegistryMSOps g_negPrimitiveCreatorRegistry("Neg", NegPrimitiveCreator); | |||
| RegistryMSOps g_negGradPrimitiveCreatorRegistry("NegGrad", NegGradPrimitiveCreator); | |||
| RegistryMSOps g_nonMaxSuppressionPrimitiveCreatorRegistry("NonMaxSuppression", NonMaxSuppressionPrimitiveCreator); | |||
| RegistryMSOps g_notEqualPrimitiveCreatorRegistry("NotEqual", NotEqualPrimitiveCreator); | |||
| RegistryMSOps g_oneHotPrimitiveCreatorRegistry("OneHot", OneHotPrimitiveCreator); | |||
| RegistryMSOps g_onesLikePrimitiveCreatorRegistry("OnesLike", OnesLikePrimitiveCreator); | |||
| RegistryMSOps g_padPrimitiveCreatorRegistry("Pad", PadFusionPrimitiveCreator); | |||
| RegistryMSOps g_padFusionPrimitiveCreatorRegistry("PadFusion", PadFusionPrimitiveCreator); | |||
| RegistryMSOps g_partialFusionPrimitiveCreatorRegistry("PartialFusion", PartialFusionPrimitiveCreator); | |||
| RegistryMSOps g_powerGradPrimitiveCreatorRegistry("PowerGrad", PowerGradPrimitiveCreator); | |||
| RegistryMSOps g_powFusionPrimitiveCreatorRegistry("PowFusion", PowFusionPrimitiveCreator); | |||
| RegistryMSOps g_pReLUFusionPrimitiveCreatorRegistry("PReLUFusion", PReLUFusionPrimitiveCreator); | |||
| RegistryMSOps g_rangePrimitiveCreatorRegistry("Range", RangePrimitiveCreator); | |||
| RegistryMSOps g_rankPrimitiveCreatorRegistry("Rank", RankPrimitiveCreator); | |||
| RegistryMSOps g_reciprocalPrimitiveCreatorRegistry("Reciprocal", ReciprocalPrimitiveCreator); | |||
| RegistryMSOps g_realDivPrimitiveCreatorRegistry("RealDiv", RealDivPrimitiveCreator); | |||
| RegistryMSOps g_reducePrimitiveCreatorRegistry("Reduce", ReduceFusionPrimitiveCreator); | |||
| RegistryMSOps g_reduceFusionPrimitiveCreatorRegistry("ReduceFusion", ReduceFusionPrimitiveCreator); | |||
| RegistryMSOps g_reshapePrimitiveCreatorRegistry("Reshape", ReshapePrimitiveCreator); | |||
| RegistryMSOps g_resizePrimitiveCreatorRegistry("Resize", ResizePrimitiveCreator); | |||
| RegistryMSOps g_reverseV2PrimitiveCreatorRegistry("ReverseV2", ReverseV2PrimitiveCreator); | |||
| RegistryMSOps g_reverseSequencePrimitiveCreatorRegistry("ReverseSequence", ReverseSequencePrimitiveCreator); | |||
| RegistryMSOps g_rfftPrimitiveCreatorRegistry("Rfft", RfftPrimitiveCreator); | |||
| RegistryMSOps g_rOIPoolingPrimitiveCreatorRegistry("ROIPooling", ROIPoolingPrimitiveCreator); | |||
| RegistryMSOps g_roundPrimitiveCreatorRegistry("Round", RoundPrimitiveCreator); | |||
| RegistryMSOps g_rsqrtPrimitiveCreatorRegistry("Rsqrt", RsqrtPrimitiveCreator); | |||
| RegistryMSOps g_quantDTypeCastPrimitiveCreatorRegistry("QuantDTypeCast", QuantDTypeCastPrimitiveCreator); | |||
| RegistryMSOps g_scalePrimitiveCreatorRegistry("Scale", ScaleFusionPrimitiveCreator); | |||
| RegistryMSOps g_scaleFusionPrimitiveCreatorRegistry("ScaleFusion", ScaleFusionPrimitiveCreator); | |||
| RegistryMSOps g_shapePrimitiveCreatorRegistry("Shape", ShapePrimitiveCreator); | |||
| RegistryMSOps g_sigmoidCrossEntropyWithLogitsPrimitiveCreatorRegistry("SigmoidCrossEntropyWithLogits", | |||
| SigmoidCrossEntropyWithLogitsPrimitiveCreator); | |||
| RegistryMSOps g_SigmoidCrossEntropyWithLogitsGradPrimitiveCreatorRegistry( | |||
| RegistryMSOps g_sigmoidCrossEntropyWithLogitsGradPrimitiveCreatorRegistry( | |||
| "SigmoidCrossEntropyWithLogitsGrad", SigmoidCrossEntropyWithLogitsGradPrimitiveCreator); | |||
| RegistryMSOps g_SinPrimitiveCreatorRegistry("Sin", SinPrimitiveCreator); | |||
| RegistryMSOps g_SkipGramPrimitiveCreatorRegistry("SkipGram", SkipGramPrimitiveCreator); | |||
| RegistryMSOps g_SliceFusionPrimitiveCreatorRegistry("SliceFusion", SliceFusionPrimitiveCreator); | |||
| RegistryMSOps g_SmoothL1LossPrimitiveCreatorRegistry("SmoothL1Loss", SmoothL1LossPrimitiveCreator); | |||
| RegistryMSOps g_SmoothL1LossGradPrimitiveCreatorRegistry("SmoothL1LossGrad", SmoothL1LossGradPrimitiveCreator); | |||
| RegistryMSOps g_SoftmaxPrimitiveCreatorRegistry("Softmax", SoftmaxPrimitiveCreator); | |||
| RegistryMSOps g_SpaceToBatchPrimitiveCreatorRegistry("SpaceToBatch", SpaceToBatchPrimitiveCreator); | |||
| RegistryMSOps g_SpaceToBatchNDPrimitiveCreatorRegistry("SpaceToBatchND", SpaceToBatchNDPrimitiveCreator); | |||
| RegistryMSOps g_SpaceToDepthPrimitiveCreatorRegistry("SpaceToDepth", SpaceToDepthPrimitiveCreator); | |||
| RegistryMSOps g_SparseToDensePrimitiveCreatorRegistry("SparseToDense", SparseToDensePrimitiveCreator); | |||
| RegistryMSOps g_SplitPrimitiveCreatorRegistry("Split", SplitPrimitiveCreator); | |||
| RegistryMSOps g_SqrtPrimitiveCreatorRegistry("Sqrt", SqrtPrimitiveCreator); | |||
| RegistryMSOps g_SqueezePrimitiveCreatorRegistry("Squeeze", SqueezePrimitiveCreator); | |||
| RegistryMSOps g_SquarePrimitiveCreatorRegistry("Square", SquarePrimitiveCreator); | |||
| RegistryMSOps g_SquaredDifferencePrimitiveCreatorRegistry("SquaredDifference", SquaredDifferencePrimitiveCreator); | |||
| RegistryMSOps g_StackPrimitiveCreatorRegistry("Stack", StackPrimitiveCreator); | |||
| RegistryMSOps g_StridedSlicePrimitiveCreatorRegistry("StridedSlice", StridedSlicePrimitiveCreator); | |||
| RegistryMSOps g_SubPrimitiveCreatorRegistry("Sub", SubFusionPrimitiveCreator); | |||
| RegistryMSOps g_SubFusionPrimitiveCreatorRegistry("SubFusion", SubFusionPrimitiveCreator); | |||
| RegistryMSOps g_SubGradPrimitiveCreatorRegistry("SubGrad", SubGradPrimitiveCreator); | |||
| RegistryMSOps g_SwitchPrimitiveCreatorRegistry("Switch", SwitchPrimitiveCreator); | |||
| RegistryMSOps g_TensorListFromTensorPrimitiveCreatorRegistry("TensorListFromTensor", | |||
| RegistryMSOps g_sinPrimitiveCreatorRegistry("Sin", SinPrimitiveCreator); | |||
| RegistryMSOps g_skipGramPrimitiveCreatorRegistry("SkipGram", SkipGramPrimitiveCreator); | |||
| RegistryMSOps g_sliceFusionPrimitiveCreatorRegistry("SliceFusion", SliceFusionPrimitiveCreator); | |||
| RegistryMSOps g_smoothL1LossPrimitiveCreatorRegistry("SmoothL1Loss", SmoothL1LossPrimitiveCreator); | |||
| RegistryMSOps g_smoothL1LossGradPrimitiveCreatorRegistry("SmoothL1LossGrad", SmoothL1LossGradPrimitiveCreator); | |||
| RegistryMSOps g_softmaxPrimitiveCreatorRegistry("Softmax", SoftmaxPrimitiveCreator); | |||
| RegistryMSOps g_spaceToBatchPrimitiveCreatorRegistry("SpaceToBatch", SpaceToBatchPrimitiveCreator); | |||
| RegistryMSOps g_spaceToBatchNDPrimitiveCreatorRegistry("SpaceToBatchND", SpaceToBatchNDPrimitiveCreator); | |||
| RegistryMSOps g_spaceToDepthPrimitiveCreatorRegistry("SpaceToDepth", SpaceToDepthPrimitiveCreator); | |||
| RegistryMSOps g_sparseToDensePrimitiveCreatorRegistry("SparseToDense", SparseToDensePrimitiveCreator); | |||
| RegistryMSOps g_splitPrimitiveCreatorRegistry("Split", SplitPrimitiveCreator); | |||
| RegistryMSOps g_sqrtPrimitiveCreatorRegistry("Sqrt", SqrtPrimitiveCreator); | |||
| RegistryMSOps g_squeezePrimitiveCreatorRegistry("Squeeze", SqueezePrimitiveCreator); | |||
| RegistryMSOps g_squarePrimitiveCreatorRegistry("Square", SquarePrimitiveCreator); | |||
| RegistryMSOps g_squaredDifferencePrimitiveCreatorRegistry("SquaredDifference", SquaredDifferencePrimitiveCreator); | |||
| RegistryMSOps g_stackPrimitiveCreatorRegistry("Stack", StackPrimitiveCreator); | |||
| RegistryMSOps g_stridedSlicePrimitiveCreatorRegistry("StridedSlice", StridedSlicePrimitiveCreator); | |||
| RegistryMSOps g_subPrimitiveCreatorRegistry("Sub", SubFusionPrimitiveCreator); | |||
| RegistryMSOps g_subFusionPrimitiveCreatorRegistry("SubFusion", SubFusionPrimitiveCreator); | |||
| RegistryMSOps g_subGradPrimitiveCreatorRegistry("SubGrad", SubGradPrimitiveCreator); | |||
| RegistryMSOps g_switchPrimitiveCreatorRegistry("Switch", SwitchPrimitiveCreator); | |||
| RegistryMSOps g_tensorListFromTensorPrimitiveCreatorRegistry("TensorListFromTensor", | |||
| TensorListFromTensorPrimitiveCreator); | |||
| RegistryMSOps g_TensorListGetItemPrimitiveCreatorRegistry("TensorListGetItem", TensorListGetItemPrimitiveCreator); | |||
| RegistryMSOps g_TensorListReservePrimitiveCreatorRegistry("TensorListReserve", TensorListReservePrimitiveCreator); | |||
| RegistryMSOps g_TensorListSetItemPrimitiveCreatorRegistry("TensorListSetItem", TensorListSetItemPrimitiveCreator); | |||
| RegistryMSOps g_TensorListStackPrimitiveCreatorRegistry("TensorListStack", TensorListStackPrimitiveCreator); | |||
| RegistryMSOps g_TileFusionPrimitiveCreatorRegistry("TileFusion", TileFusionPrimitiveCreator); | |||
| RegistryMSOps g_TopKPrimitiveCreatorRegistry("TopK", TopKFusionPrimitiveCreator); | |||
| RegistryMSOps g_TopKFusionPrimitiveCreatorRegistry("TopKFusion", TopKFusionPrimitiveCreator); | |||
| RegistryMSOps g_TransposePrimitiveCreatorxRegistry("Transpose", TransposePrimitiveCreator); | |||
| RegistryMSOps g_UniquePrimitiveCreatorRegistry("Unique", UniquePrimitiveCreator); | |||
| RegistryMSOps g_UnpackPrimitiveCreatorRegistry("Unpack", UnpackPrimitiveCreator); | |||
| RegistryMSOps g_UnsortedSegmentSumPrimitiveCreatorRegistry("UnsortedSegmentSum", UnsortedSegmentSumPrimitiveCreator); | |||
| RegistryMSOps g_UnsqueezePrimitiveCreatorRegistry("Unsqueeze", UnsqueezePrimitiveCreator); | |||
| RegistryMSOps g_WherePrimitiveCreatorRegistry("Where", WherePrimitiveCreator); | |||
| RegistryMSOps g_ZerosLikePrimitiveCreatorRegistry("ZerosLike", ZerosLikePrimitiveCreator); | |||
| RegistryMSOps g_tensorListGetItemPrimitiveCreatorRegistry("TensorListGetItem", TensorListGetItemPrimitiveCreator); | |||
| RegistryMSOps g_tensorListReservePrimitiveCreatorRegistry("TensorListReserve", TensorListReservePrimitiveCreator); | |||
| RegistryMSOps g_tensorListSetItemPrimitiveCreatorRegistry("TensorListSetItem", TensorListSetItemPrimitiveCreator); | |||
| RegistryMSOps g_tensorListStackPrimitiveCreatorRegistry("TensorListStack", TensorListStackPrimitiveCreator); | |||
| RegistryMSOps g_tileFusionPrimitiveCreatorRegistry("TileFusion", TileFusionPrimitiveCreator); | |||
| RegistryMSOps g_topKPrimitiveCreatorRegistry("TopK", TopKFusionPrimitiveCreator); | |||
| RegistryMSOps g_topKFusionPrimitiveCreatorRegistry("TopKFusion", TopKFusionPrimitiveCreator); | |||
| RegistryMSOps g_transposePrimitiveCreatorxRegistry("Transpose", TransposePrimitiveCreator); | |||
| RegistryMSOps g_uniquePrimitiveCreatorRegistry("Unique", UniquePrimitiveCreator); | |||
| RegistryMSOps g_unpackPrimitiveCreatorRegistry("Unpack", UnpackPrimitiveCreator); | |||
| RegistryMSOps g_unsortedSegmentSumPrimitiveCreatorRegistry("UnsortedSegmentSum", UnsortedSegmentSumPrimitiveCreator); | |||
| RegistryMSOps g_unsqueezePrimitiveCreatorRegistry("Unsqueeze", UnsqueezePrimitiveCreator); | |||
| RegistryMSOps g_wherePrimitiveCreatorRegistry("Where", WherePrimitiveCreator); | |||
| RegistryMSOps g_zerosLikePrimitiveCreatorRegistry("ZerosLike", ZerosLikePrimitiveCreator); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -21,59 +21,43 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *CaffeReluParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) { | |||
| auto primitive_c = new (std::nothrow) ops::Activation(); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new ReLU failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Activation>(); | |||
| primitive_c->set_activation_type(mindspore::ActivationType::RELU); | |||
| prim->set_activation_type(mindspore::ActivationType::RELU); | |||
| if (proto.has_relu_param() && proto.relu_param().has_negative_slope()) { | |||
| float negative_slope = proto.relu_param().negative_slope(); | |||
| if (negative_slope != 0) { | |||
| primitive_c->set_activation_type(mindspore::ActivationType::LEAKY_RELU); | |||
| primitive_c->set_alpha(negative_slope); | |||
| prim->set_activation_type(mindspore::ActivationType::LEAKY_RELU); | |||
| prim->set_alpha(negative_slope); | |||
| } | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| ops::PrimitiveC *CaffeRelu6Parser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) { | |||
| auto primitive_c = new (std::nothrow) ops::Activation(); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Relu6 failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Activation>(); | |||
| primitive_c->set_activation_type(mindspore::ActivationType::RELU6); | |||
| prim->set_activation_type(mindspore::ActivationType::RELU6); | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| ops::PrimitiveC *CaffeSigmoidParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) { | |||
| auto primitive_c = new (std::nothrow) ops::Activation(); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Sigmoid failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Activation>(); | |||
| primitive_c->set_activation_type(mindspore::ActivationType::SIGMOID); | |||
| prim->set_activation_type(mindspore::ActivationType::SIGMOID); | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| ops::PrimitiveC *CaffeTanhParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) { | |||
| auto primitive_c = new (std::nothrow) ops::Activation(); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Tanh failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Activation>(); | |||
| primitive_c->set_activation_type(mindspore::ActivationType::TANH); | |||
| prim->set_activation_type(mindspore::ActivationType::TANH); | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| CaffeNodeRegistrar g_caffeReluParser("ReLU", new CaffeReluParser()); | |||
| @@ -21,28 +21,24 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *CaffeArgMaxParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) { | |||
| auto primitive_c = new (std::nothrow) ops::ArgMaxFusion(); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new ArgMaxFusion failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::ArgMaxFusion>(); | |||
| primitive_c->set_keep_dims(true); | |||
| primitive_c->set_out_max_value(false); | |||
| primitive_c->set_top_k(1); | |||
| prim->set_keep_dims(true); | |||
| prim->set_out_max_value(false); | |||
| prim->set_top_k(1); | |||
| const caffe::ArgMaxParameter &argmaxParam = proto.argmax_param(); | |||
| if (argmaxParam.has_out_max_val()) { | |||
| primitive_c->set_out_max_value(argmaxParam.out_max_val()); | |||
| prim->set_out_max_value(argmaxParam.out_max_val()); | |||
| } | |||
| if (argmaxParam.has_top_k()) { | |||
| primitive_c->set_top_k(argmaxParam.top_k()); | |||
| prim->set_top_k(argmaxParam.top_k()); | |||
| } | |||
| if (argmaxParam.has_axis()) { | |||
| primitive_c->set_axis(argmaxParam.axis()); | |||
| prim->set_axis(argmaxParam.axis()); | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| CaffeNodeRegistrar g_caffeArgMaxParser("ArgMax", new CaffeArgMaxParser()); | |||
| @@ -24,11 +24,10 @@ namespace mindspore { | |||
| namespace lite { | |||
| using STATUS = int; | |||
| ops::PrimitiveC *CaffeBatchNormParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) { | |||
| auto primitive_c = new (std::nothrow) ops::BatchNorm(); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new BatchNorm failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::BatchNorm>(); | |||
| prim->set_is_training(false); | |||
| prim->set_format(mindspore::NCHW); | |||
| const caffe::BatchNormParameter &batchNormParam = proto.batch_norm_param(); | |||
| if (proto.bottom_size() != 1) { | |||
| @@ -46,12 +45,9 @@ ops::PrimitiveC *CaffeBatchNormParser::Parse(const caffe::LayerParameter &proto, | |||
| if (batchNormParam.has_eps() && std::fabs(1e-5 - batchNormParam.eps()) >= 1e-9) { | |||
| epsilon = batchNormParam.eps(); | |||
| } | |||
| primitive_c->set_epsilon(epsilon); | |||
| primitive_c->set_is_training(false); | |||
| primitive_c->set_format(mindspore::NCHW); | |||
| prim->set_epsilon(epsilon); | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| CaffeNodeRegistrar g_caffeBatchNormParser("BatchNorm", new CaffeBatchNormParser()); | |||
| @@ -21,11 +21,7 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *CaffeConcatParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) { | |||
| auto primitive_c = new (std::nothrow) ops::Concat(); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Concat failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Concat>(); | |||
| const caffe::ConcatParameter &concatParam = proto.concat_param(); | |||
| if (concatParam.has_axis() && concatParam.has_concat_dim()) { | |||
| @@ -45,9 +41,9 @@ ops::PrimitiveC *CaffeConcatParser::Parse(const caffe::LayerParameter &proto, co | |||
| MS_LOG(DEBUG) << "set axis: " << concatParam.axis(); | |||
| axis = concatParam.axis(); | |||
| } | |||
| primitive_c->set_axis(axis); | |||
| prim->set_axis(axis); | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| CaffeNodeRegistrar g_caffeConcatParser("Concat", new CaffeConcatParser()); | |||
| @@ -22,61 +22,52 @@ namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *CaffeConvolutionParser::Parse(const caffe::LayerParameter &proto, | |||
| const caffe::LayerParameter &weight) { | |||
| auto primitive_c = new (std::nothrow) ops::Conv2DFusion(); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Conv2DFusion failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Conv2DFusion>(); | |||
| primitive_c->set_pad({0, 0, 0, 0}); | |||
| primitive_c->set_pad_mode(mindspore::PadMode::PAD); | |||
| primitive_c->set_format(mindspore::Format::NCHW); | |||
| primitive_c->set_activation_type(mindspore::NO_ACTIVATION); | |||
| prim->set_pad({0, 0, 0, 0}); | |||
| prim->set_pad_mode(mindspore::PadMode::PAD); | |||
| prim->set_format(mindspore::Format::NCHW); | |||
| prim->set_activation_type(mindspore::NO_ACTIVATION); | |||
| const caffe::ConvolutionParameter &convParam = proto.convolution_param(); | |||
| // parse kernel | |||
| std::vector<int64_t> kernel(2, 0); | |||
| if (CaffeConvBaseParser::ParseKernels(convParam, &kernel) != RET_OK) { | |||
| MS_LOG(ERROR) << "ParseKernels for " << proto.name().c_str() << " failed"; | |||
| return nullptr; | |||
| } | |||
| primitive_c->set_kernel_size(kernel); | |||
| prim->set_kernel_size(kernel); | |||
| // parse stride | |||
| std::vector<int64_t> stride(2, 0); | |||
| if (CaffeConvBaseParser::ParseStrides(convParam, &stride) != RET_OK) { | |||
| MS_LOG(ERROR) << "ParseStrides for " << proto.name().c_str() << " failed"; | |||
| return nullptr; | |||
| } | |||
| primitive_c->set_stride(stride); | |||
| prim->set_stride(stride); | |||
| // parse dilation | |||
| std::vector<int64_t> dilation(2, 0); | |||
| if (CaffeConvBaseParser::ParseDilations(convParam, &dilation) != RET_OK) { | |||
| MS_LOG(ERROR) << "ParseDilations for " << proto.name().c_str() << " failed"; | |||
| return nullptr; | |||
| } | |||
| primitive_c->set_dilation(dilation); | |||
| prim->set_dilation(dilation); | |||
| // parse pad | |||
| std::vector<int64_t> pad(4, 0); | |||
| if (CaffeConvBaseParser::ParsePads(convParam, &pad) != RET_OK) { | |||
| MS_LOG(ERROR) << "ParsePads for " << proto.name().c_str() << " failed"; | |||
| return nullptr; | |||
| } | |||
| primitive_c->set_pad_list(pad); | |||
| prim->set_pad_list(pad); | |||
| // parse channelOut | |||
| int channel_out = 0; | |||
| if (CaffeConvBaseParser::ParseChannelOut(convParam, &channel_out) != RET_OK) { | |||
| MS_LOG(ERROR) << "conv channel out failed"; | |||
| return nullptr; | |||
| } | |||
| primitive_c->set_out_channel(channel_out); | |||
| prim->set_out_channel(channel_out); | |||
| // parse group | |||
| auto group = CaffeConvBaseParser::ParseGroup(convParam, proto.type()); | |||
| primitive_c->set_group(group); | |||
| prim->set_group(group); | |||
| // parse channelIn | |||
| if (weight.blobs_size() < 1) { | |||
| @@ -85,11 +76,13 @@ ops::PrimitiveC *CaffeConvolutionParser::Parse(const caffe::LayerParameter &prot | |||
| } | |||
| auto &weightBlob = weight.blobs(0); | |||
| auto channelIn = weightBlob.has_shape() ? weightBlob.shape().dim(1) * group : weightBlob.channels() * group; | |||
| primitive_c->set_in_channel(channelIn); | |||
| prim->set_in_channel(channelIn); | |||
| if (group != 1) { | |||
| primitive_c->AddAttr(ops::kIsDepthWise, MakeValue<bool>(true)); | |||
| prim->AddAttr(ops::kIsDepthWise, MakeValue<bool>(true)); | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| CaffeNodeRegistrar g_caffeConvolutionParser("Convolution", new CaffeConvolutionParser()); | |||
| @@ -21,25 +21,21 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *CaffeCropParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) { | |||
| auto primitive_c = new (std::nothrow) ops::Crop(); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Crop failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Crop>(); | |||
| if (!proto.has_crop_param()) { | |||
| primitive_c->set_axis(2); | |||
| prim->set_axis(2); | |||
| std::vector<int64_t> offsets(2, 0); | |||
| primitive_c->set_offsets(offsets); | |||
| prim->set_offsets(offsets); | |||
| } else { | |||
| const caffe::CropParameter &cropParam = proto.crop_param(); | |||
| if (cropParam.has_axis()) { | |||
| if (cropParam.axis() == -1) { | |||
| MS_LOG(WARNING) << "axis with -1 may lead to calculation errors when input less than 4 dims."; | |||
| } | |||
| primitive_c->set_axis(cropParam.axis()); | |||
| prim->set_axis(cropParam.axis()); | |||
| } else { | |||
| primitive_c->set_axis(2); | |||
| prim->set_axis(2); | |||
| } | |||
| if (cropParam.offset_size() != 0) { | |||
| @@ -48,11 +44,11 @@ ops::PrimitiveC *CaffeCropParser::Parse(const caffe::LayerParameter &proto, cons | |||
| for (int i = 0; i < cropParam.offset_size(); i++) { | |||
| offsets.push_back(cropParam.offset(i)); | |||
| } | |||
| primitive_c->set_offsets(offsets); | |||
| prim->set_offsets(offsets); | |||
| } | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| CaffeNodeRegistrar g_caffeCropParser("Crop", new CaffeCropParser()); | |||
| @@ -20,78 +20,69 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *CaffeDeconvolutionParser::Parse(const caffe::LayerParameter &proto, | |||
| const caffe::LayerParameter &weight) { | |||
| auto primitive_c = new (std::nothrow) ops::Conv2dTransposeFusion(); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Conv2dTransposeFusion failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Conv2dTransposeFusion>(); | |||
| primitive_c->set_pad({0, 0, 0, 0}); | |||
| primitive_c->set_format(mindspore::Format::NCHW); | |||
| primitive_c->set_pad_mode(mindspore::PadMode::PAD); | |||
| prim->set_pad({0, 0, 0, 0}); | |||
| prim->set_format(mindspore::Format::NCHW); | |||
| prim->set_pad_mode(mindspore::PadMode::PAD); | |||
| const caffe::ConvolutionParameter &convParam = proto.convolution_param(); | |||
| // parse pad | |||
| std::vector<int64_t> pad(4, 0); | |||
| if (CaffeConvBaseParser::ParsePads(convParam, &pad) != RET_OK) { | |||
| MS_LOG(ERROR) << "ParsePads for " << proto.name().c_str() << " failed"; | |||
| return nullptr; | |||
| } | |||
| primitive_c->set_pad_list({pad[0], pad[1], pad[2], pad[3]}); | |||
| prim->set_pad_list({pad[0], pad[1], pad[2], pad[3]}); | |||
| // parse stride | |||
| std::vector<int64_t> stride(2, 0); | |||
| if (CaffeConvBaseParser::ParseStrides(convParam, &stride) != RET_OK) { | |||
| MS_LOG(ERROR) << "ParseStrides for " << proto.name().c_str() << " failed"; | |||
| return nullptr; | |||
| } | |||
| primitive_c->set_stride({stride[0], stride[1]}); | |||
| prim->set_stride({stride[0], stride[1]}); | |||
| // parse dilation | |||
| std::vector<int64_t> dilation(2, 0); | |||
| if (CaffeConvBaseParser::ParseDilations(convParam, &dilation) != RET_OK) { | |||
| MS_LOG(ERROR) << "ParseDilations for " << proto.name().c_str() << " failed"; | |||
| return nullptr; | |||
| } | |||
| primitive_c->set_dilation({dilation[0], dilation[1]}); | |||
| prim->set_dilation({dilation[0], dilation[1]}); | |||
| // parse kernel | |||
| std::vector<int64_t> kernel(2, 0); | |||
| if (CaffeConvBaseParser::ParseKernels(convParam, &kernel) != RET_OK) { | |||
| MS_LOG(ERROR) << "ParseKernels for " << proto.name().c_str() << " failed"; | |||
| return nullptr; | |||
| } | |||
| primitive_c->set_kernel_size({kernel[0], kernel[1]}); | |||
| prim->set_kernel_size({kernel[0], kernel[1]}); | |||
| // parse group | |||
| auto group = CaffeConvBaseParser::ParseGroup(convParam, proto.type()); | |||
| primitive_c->set_group(group); | |||
| prim->set_group(group); | |||
| // parse channelOut | |||
| int32_t channelOut; | |||
| if (CaffeConvBaseParser::ParseChannelOut(convParam, &channelOut) != RET_OK) { | |||
| MS_LOG(ERROR) << "deconv channel get failed"; | |||
| return nullptr; | |||
| } | |||
| primitive_c->set_out_channel((int64_t)channelOut); | |||
| prim->set_out_channel((int64_t)channelOut); | |||
| // parse channelIN | |||
| auto &weightBlob = weight.blobs(0); | |||
| if (weightBlob.has_shape()) { | |||
| if (group == 1) | |||
| primitive_c->set_in_channel(weightBlob.shape().dim(0) * group); | |||
| prim->set_in_channel(weightBlob.shape().dim(0) * group); | |||
| else | |||
| primitive_c->set_in_channel(weightBlob.shape().dim(1) * group); | |||
| prim->set_in_channel(weightBlob.shape().dim(1) * group); | |||
| } else { | |||
| primitive_c->set_in_channel(weightBlob.num() * group); | |||
| prim->set_in_channel(weightBlob.num() * group); | |||
| } | |||
| if (group != 1) { | |||
| primitive_c->AddAttr(ops::kIsDepthWise, MakeValue<bool>(true)); | |||
| prim->AddAttr(ops::kIsDepthWise, MakeValue<bool>(true)); | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| CaffeNodeRegistrar g_caffeDeconvolutionParser("Deconvolution", new CaffeDeconvolutionParser()); | |||
| @@ -22,11 +22,7 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *CaffeEltwiseParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) { | |||
| auto primitive_c = new (std::nothrow) ops::Eltwise(); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Eltwise failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Eltwise>(); | |||
| if (proto.bottom_size() < 2) { | |||
| MS_LOG(ERROR) << "Eltwise Op " << proto.name() << " need at least 2 inputs,but input size is " | |||
| @@ -55,23 +51,23 @@ ops::PrimitiveC *CaffeEltwiseParser::Parse(const caffe::LayerParameter &proto, c | |||
| if (proto.has_eltwise_param() && eltwiseParam.has_operation()) { | |||
| switch (eltwiseParam.operation()) { | |||
| case caffe::EltwiseParameter::PROD: | |||
| primitive_c->set_mode(mindspore::EltwiseMode::PROD); | |||
| prim->set_mode(mindspore::EltwiseMode::PROD); | |||
| break; | |||
| case caffe::EltwiseParameter::SUM: | |||
| primitive_c->set_mode(mindspore::EltwiseMode::SUM); | |||
| prim->set_mode(mindspore::EltwiseMode::SUM); | |||
| break; | |||
| case caffe::EltwiseParameter::MAX: | |||
| primitive_c->set_mode(mindspore::EltwiseMode::MAXIMUM); | |||
| prim->set_mode(mindspore::EltwiseMode::MAXIMUM); | |||
| break; | |||
| default: | |||
| MS_LOG(ERROR) << "Eltwise parse params fail, unsupported operation: " << eltwiseParam.operation(); | |||
| return nullptr; | |||
| } | |||
| } else { | |||
| primitive_c->set_mode(mindspore::EltwiseMode::SUM); | |||
| prim->set_mode(mindspore::EltwiseMode::SUM); | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| CaffeNodeRegistrar g_caffeEltwiseParser("Eltwise", new CaffeEltwiseParser()); | |||
| @@ -21,20 +21,16 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *CaffeEluParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) { | |||
| auto primitive_c = new (std::nothrow) ops::Elu(); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Elu failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Elu>(); | |||
| if (proto.has_elu_param()) { | |||
| const caffe::ELUParameter &eluParameter = proto.elu_param(); | |||
| if (eluParameter.has_alpha()) { | |||
| primitive_c->set_alpha(eluParameter.alpha()); | |||
| prim->set_alpha(eluParameter.alpha()); | |||
| } | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| CaffeNodeRegistrar g_caffeEluParser("ELU", new CaffeEluParser()); | |||
| @@ -22,29 +22,26 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *CaffeExpParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) { | |||
| auto primitive_c = new (std::nothrow) ops::ExpFusion(); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new ExpFusion failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::ExpFusion>(); | |||
| const caffe::ExpParameter &exp_param = proto.exp_param(); | |||
| if (exp_param.has_base()) { | |||
| primitive_c->set_base(exp_param.base()); | |||
| prim->set_base(exp_param.base()); | |||
| } else { | |||
| primitive_c->set_base(-1); // -1 represent base = e | |||
| prim->set_base(-1); // -1 represent base = e | |||
| } | |||
| if (exp_param.has_scale()) { | |||
| primitive_c->set_scale(exp_param.scale()); | |||
| prim->set_scale(exp_param.scale()); | |||
| } else { | |||
| primitive_c->set_scale(1); | |||
| prim->set_scale(1); | |||
| } | |||
| if (exp_param.has_shift()) { | |||
| primitive_c->set_shift(exp_param.shift()); | |||
| prim->set_shift(exp_param.shift()); | |||
| } else { | |||
| primitive_c->set_shift(0); | |||
| prim->set_shift(0); | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| CaffeNodeRegistrar g_caffeExpParser("Exp", new CaffeExpParser()); | |||
| @@ -21,13 +21,9 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *CaffeFlattenParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) { | |||
| auto primitive_c = new (std::nothrow) ops::Flatten(); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Flatten failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Flatten>(); | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| CaffeNodeRegistrar g_CaffeFlattenParser("Flatten", new CaffeFlattenParser()); | |||
| @@ -22,11 +22,9 @@ namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *CaffeInnerProductParser::Parse(const caffe::LayerParameter &proto, | |||
| const caffe::LayerParameter &weight) { | |||
| auto primitive_c = new (std::nothrow) ops::FullConnection(); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new FullConnection failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::FullConnection>(); | |||
| prim->set_activation_type(mindspore::ActivationType::NO_ACTIVATION); | |||
| const caffe::InnerProductParameter &innerProductParam = proto.inner_product_param(); | |||
| if (!innerProductParam.has_num_output()) { | |||
| @@ -35,19 +33,17 @@ ops::PrimitiveC *CaffeInnerProductParser::Parse(const caffe::LayerParameter &pro | |||
| } | |||
| if (innerProductParam.axis() == 1) { | |||
| primitive_c->set_axis(1); | |||
| primitive_c->set_use_axis(true); | |||
| prim->set_axis(1); | |||
| prim->set_use_axis(true); | |||
| } else { | |||
| MS_LOG(ERROR) << "InnerProduct Parse axis only support default 1, but actually " << innerProductParam.axis(); | |||
| return nullptr; | |||
| } | |||
| if (innerProductParam.bias_term()) { | |||
| primitive_c->set_has_bias(true); | |||
| prim->set_has_bias(true); | |||
| } | |||
| primitive_c->set_activation_type(mindspore::ActivationType::NO_ACTIVATION); | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| CaffeNodeRegistrar g_caffeInnerProductParser("InnerProduct", new CaffeInnerProductParser()); | |||
| @@ -21,11 +21,10 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *CaffeInterpParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) { | |||
| auto primitive_c = new (std::nothrow) ops::Resize(); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Resize failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Resize>(); | |||
| prim->set_method(mindspore::ResizeMethod::LINEAR); | |||
| prim->set_coordinate_transform_mode(mindspore::CoordinateTransformMode::ALIGN_CORNERS); | |||
| const caffe::InterpParameter &interpParam = proto.interp_param(); | |||
| if (interpParam.has_height()) { | |||
| @@ -34,7 +33,7 @@ ops::PrimitiveC *CaffeInterpParser::Parse(const caffe::LayerParameter &proto, co | |||
| MS_LOG(ERROR) << "Interp height must be > 0"; | |||
| return nullptr; | |||
| } | |||
| primitive_c->set_new_height(height); | |||
| prim->set_new_height(height); | |||
| } | |||
| if (interpParam.has_width()) { | |||
| @@ -43,12 +42,10 @@ ops::PrimitiveC *CaffeInterpParser::Parse(const caffe::LayerParameter &proto, co | |||
| MS_LOG(ERROR) << "Interp width must be > 0"; | |||
| return nullptr; | |||
| } | |||
| primitive_c->set_new_width(width); | |||
| prim->set_new_width(width); | |||
| } | |||
| primitive_c->set_method(mindspore::ResizeMethod::LINEAR); | |||
| primitive_c->set_coordinate_transform_mode(mindspore::CoordinateTransformMode::ALIGN_CORNERS); | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| CaffeNodeRegistrar g_caffeInterpParser("Interp", new CaffeInterpParser()); | |||
| @@ -95,8 +95,8 @@ STATUS CaffeModelParser::ConvertLayers() { | |||
| continue; | |||
| } | |||
| auto primitive_c = node_parser->Parse(layer, weight); | |||
| if (primitive_c == nullptr) { | |||
| auto prim = node_parser->Parse(layer, weight); | |||
| if (prim == nullptr) { | |||
| MS_LOG(ERROR) << "parse node " << layer.name() << " failed."; | |||
| status = RET_ERROR; | |||
| continue; | |||
| @@ -119,7 +119,7 @@ STATUS CaffeModelParser::ConvertLayers() { | |||
| } | |||
| // build cnode | |||
| std::vector<AnfNodePtr> op_inputs = {NewValueNode(std::shared_ptr<ops::PrimitiveC>(primitive_c))}; | |||
| std::vector<AnfNodePtr> op_inputs = {NewValueNode(std::shared_ptr<ops::PrimitiveC>(prim))}; | |||
| op_inputs.insert(op_inputs.end(), input_nodes.begin(), input_nodes.end()); | |||
| op_inputs.insert(op_inputs.end(), const_parameters.begin(), const_parameters.end()); | |||
| auto new_cnode = func_graph_ptr_->NewCNode(op_inputs); | |||
| @@ -132,7 +132,7 @@ STATUS CaffeModelParser::ConvertLayers() { | |||
| continue; | |||
| } | |||
| status = ConvertLayerQuantParams(layer, weight, primitive_c); | |||
| status = ConvertLayerQuantParams(layer, weight, prim); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Convert quant params for " << layer.name() << " failed."; | |||
| continue; | |||
| @@ -294,9 +294,9 @@ STATUS CaffeModelParser::ConvertGraphOutputs() { | |||
| } | |||
| STATUS CaffeModelParser::ConvertLayerQuantParams(const caffe::LayerParameter &layer, | |||
| const caffe::LayerParameter &weight, ops::PrimitiveC *primitive_c) { | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "primitive_c is null, get quant params failed."; | |||
| const caffe::LayerParameter &weight, ops::PrimitiveC *prim) { | |||
| if (prim == nullptr) { | |||
| MS_LOG(ERROR) << "prim is null, get quant params failed."; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto quant_params_holder = std::make_shared<QuantParamHolder>(); | |||
| @@ -312,7 +312,7 @@ STATUS CaffeModelParser::ConvertLayerQuantParams(const caffe::LayerParameter &la | |||
| std::vector<schema::QuantParamT> notinited_quant_params(1); | |||
| quant_params_holder->AddOutputQuantParam(notinited_quant_params); | |||
| } | |||
| primitive_c->AddAttr("quant_params", quant_params_holder); | |||
| prim->AddAttr("quant_params", quant_params_holder); | |||
| return RET_OK; | |||
| } | |||
| @@ -45,7 +45,7 @@ class CaffeModelParser : public ModelParser { | |||
| STATUS ConvertLayers(); | |||
| static STATUS ConvertLayerQuantParams(const caffe::LayerParameter &layer, const caffe::LayerParameter &weight, | |||
| ops::PrimitiveC *primitive_c); | |||
| ops::PrimitiveC *prim); | |||
| STATUS ConvertBlobs(const caffe::LayerParameter &layer, std::vector<ParameterPtr> *const_parameters); | |||
| @@ -21,11 +21,7 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *CaffePermuteParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) { | |||
| auto primitive_c = new (std::nothrow) ops::Transpose(); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Transpose failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Transpose>(); | |||
| std::vector<int32_t> perm; | |||
| const caffe::PermuteParameter &permuteParam = proto.permute_param(); | |||
| @@ -34,9 +30,9 @@ ops::PrimitiveC *CaffePermuteParser::Parse(const caffe::LayerParameter &proto, c | |||
| for (int i = 0; i < num_order_dims; ++i) { | |||
| perm[i] = permuteParam.order()[i]; | |||
| } | |||
| primitive_c->AddAttr("perm", MakeValue(perm)); | |||
| prim->AddAttr("perm", MakeValue(perm)); | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| CaffeNodeRegistrar g_caffePermuteParser("Permute", new CaffePermuteParser()); | |||
| @@ -124,31 +124,23 @@ ops::PrimitiveC *CaffePoolingParser::Parse(const caffe::LayerParameter &proto, c | |||
| auto roundMode = ParseRoundMode(poolingParam); | |||
| if (poolingParam.pool() == caffe::PoolingParameter::MAX) { | |||
| auto primitive_c = new (std::nothrow) ops::MaxPoolFusion(); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new MaxPoolFusion failed"; | |||
| return nullptr; | |||
| } | |||
| primitive_c->set_format(mindspore::Format::NCHW); | |||
| primitive_c->set_pad_mode(mindspore::PadMode::PAD); | |||
| primitive_c->set_kernel_size(windows); | |||
| primitive_c->set_strides(strides); | |||
| primitive_c->set_pad(pad); | |||
| primitive_c->set_round_mode(roundMode); | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::MaxPoolFusion>(); | |||
| prim->set_format(mindspore::Format::NCHW); | |||
| prim->set_pad_mode(mindspore::PadMode::PAD); | |||
| prim->set_kernel_size(windows); | |||
| prim->set_strides(strides); | |||
| prim->set_pad(pad); | |||
| prim->set_round_mode(roundMode); | |||
| return prim.release(); | |||
| } else if (poolingParam.pool() == caffe::PoolingParameter::AVE) { | |||
| auto primitive_c = new (std::nothrow) ops::AvgPoolFusion(); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new AvgPoolFusion failed"; | |||
| return nullptr; | |||
| } | |||
| primitive_c->set_format(mindspore::Format::NCHW); | |||
| primitive_c->set_pad_mode(mindspore::PadMode::PAD); | |||
| primitive_c->set_kernel_size(windows); | |||
| primitive_c->set_strides(strides); | |||
| primitive_c->set_pad(pad); | |||
| primitive_c->set_round_mode(roundMode); | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::AvgPoolFusion>(); | |||
| prim->set_format(mindspore::Format::NCHW); | |||
| prim->set_pad_mode(mindspore::PadMode::PAD); | |||
| prim->set_kernel_size(windows); | |||
| prim->set_strides(strides); | |||
| prim->set_pad(pad); | |||
| prim->set_round_mode(roundMode); | |||
| return prim.release(); | |||
| } else { | |||
| MS_LOG(ERROR) << "poolingParam.pool() is not MAX or AVE"; | |||
| return nullptr; | |||
| @@ -21,11 +21,7 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *CaffePowerParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) { | |||
| auto primitive_c = new (std::nothrow) ops::PowFusion(); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new PowFusion failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::PowFusion>(); | |||
| const caffe::PowerParameter &powerParam = proto.power_param(); | |||
| float power = 1.0; | |||
| @@ -42,11 +38,11 @@ ops::PrimitiveC *CaffePowerParser::Parse(const caffe::LayerParameter &proto, con | |||
| shift = powerParam.shift(); | |||
| } | |||
| } | |||
| primitive_c->AddAttr("power", MakeValue(power)); | |||
| primitive_c->set_scale(scale); | |||
| primitive_c->set_shift(shift); | |||
| prim->AddAttr("power", MakeValue(power)); | |||
| prim->set_scale(scale); | |||
| prim->set_shift(shift); | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| CaffeNodeRegistrar g_caffePowerParser("Power", new CaffePowerParser()); | |||
| @@ -21,20 +21,16 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *CaffePReluParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) { | |||
| auto primitive_c = new (std::nothrow) ops::PReLUFusion(); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new PReLUFusion failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::PReLUFusion>(); | |||
| const caffe::PReLUParameter &pReluParam = proto.prelu_param(); | |||
| if (pReluParam.has_channel_shared()) { | |||
| primitive_c->set_channel_shared(pReluParam.channel_shared()); | |||
| const caffe::PReLUParameter &prelu_param = proto.prelu_param(); | |||
| if (prelu_param.has_channel_shared()) { | |||
| prim->set_channel_shared(prelu_param.channel_shared()); | |||
| } else { | |||
| primitive_c->set_channel_shared(false); | |||
| prim->set_channel_shared(false); | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| CaffeNodeRegistrar g_caffePReluParser("PReLU", new CaffePReluParser()); | |||
| @@ -22,30 +22,26 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *CaffeReduceParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) { | |||
| auto primitive_c = new (std::nothrow) ops::ReduceFusion(); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new ReduceFusion failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::ReduceFusion>(); | |||
| primitive_c->set_keep_dims(false); | |||
| prim->set_keep_dims(false); | |||
| const caffe::ReductionParameter &reduce_param = proto.reduction_param(); | |||
| if (reduce_param.has_operation()) { | |||
| if (reduce_param.operation() == caffe::ReductionParameter_ReductionOp_MEAN) { | |||
| primitive_c->set_mode(mindspore::ReduceMode::Reduce_Mean); | |||
| prim->set_mode(mindspore::ReduceMode::Reduce_Mean); | |||
| } else if (reduce_param.operation() == caffe::ReductionParameter_ReductionOp_SUM) { | |||
| primitive_c->set_mode(mindspore::ReduceMode::Reduce_Sum); | |||
| prim->set_mode(mindspore::ReduceMode::Reduce_Sum); | |||
| } else if (reduce_param.operation() == caffe::ReductionParameter_ReductionOp_SUMSQ) { | |||
| primitive_c->set_mode(mindspore::ReduceMode::Reduce_Sum_Square); | |||
| prim->set_mode(mindspore::ReduceMode::Reduce_Sum_Square); | |||
| } else if (reduce_param.operation() == caffe::ReductionParameter_ReductionOp_ASUM) { | |||
| primitive_c->set_mode(mindspore::ReduceMode::Reduce_ASum); | |||
| prim->set_mode(mindspore::ReduceMode::Reduce_ASum); | |||
| } else { | |||
| MS_LOG(ERROR) << "nsupported reduce mode: " << reduce_param.operation(); | |||
| return nullptr; | |||
| } | |||
| } else { | |||
| primitive_c->set_mode(mindspore::ReduceMode::Reduce_Sum); | |||
| prim->set_mode(mindspore::ReduceMode::Reduce_Sum); | |||
| } | |||
| std::vector<int32_t> axes; | |||
| @@ -56,9 +52,9 @@ ops::PrimitiveC *CaffeReduceParser::Parse(const caffe::LayerParameter &proto, co | |||
| axes.push_back(1); | |||
| axes.push_back(0); | |||
| } | |||
| primitive_c->AddAttr("axes", MakeValue(axes)); | |||
| prim->AddAttr("axes", MakeValue(axes)); | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| CaffeNodeRegistrar g_caffeReduceParser("Reduction", new CaffeReduceParser()); | |||
| @@ -21,11 +21,7 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *CaffeReshapeParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) { | |||
| auto primitive_c = new (std::nothrow) ops::Reshape(); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Reshape failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Reshape>(); | |||
| const caffe::ReshapeParameter &reshapeParam = proto.reshape_param(); | |||
| if (!reshapeParam.has_shape()) { | |||
| @@ -37,9 +33,9 @@ ops::PrimitiveC *CaffeReshapeParser::Parse(const caffe::LayerParameter &proto, c | |||
| for (int i = 0; i < blob_shape.dim_size(); i++) { | |||
| shape.push_back(blob_shape.dim(i)); | |||
| } | |||
| primitive_c->AddAttr("shape", MakeValue(shape)); | |||
| prim->AddAttr("shape", MakeValue(shape)); | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| CaffeNodeRegistrar g_caffeReshapeParser("Reshape", new CaffeReshapeParser()); | |||
| @@ -20,26 +20,8 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS CaffeScaleParser::GetAxisIndex(const int32_t &axis, uint32_t *axis_index) { | |||
| if (axis < -4 || axis >= 4) { | |||
| MS_LOG(ERROR) << "Scale axis value(" << axis << ") is not correct"; | |||
| return RET_ERROR; | |||
| } | |||
| if (axis == -1) { | |||
| MS_LOG(WARNING) << "axis with -1 may lead to calculation errors when input less than 4 dims."; | |||
| } | |||
| *axis_index = (axis + 4) % 4; | |||
| return RET_OK; | |||
| } | |||
| ops::PrimitiveC *CaffeScaleParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) { | |||
| auto primitive_c = new (std::nothrow) ops::ScaleFusion(); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new ScaleFusion failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::ScaleFusion>(); | |||
| if (weight.blobs_size() + weight.bottom_size() < 2) { | |||
| MS_LOG(ERROR) << "Scale bottom size:" << weight.bottom_size() << ", blobs size:" << weight.blobs_size() | |||
| @@ -58,9 +40,9 @@ ops::PrimitiveC *CaffeScaleParser::Parse(const caffe::LayerParameter &proto, con | |||
| MS_LOG(WARNING) << "axis with -1 may lead to calculation errors when input less than 4 dims."; | |||
| } | |||
| } | |||
| primitive_c->set_axis(1); | |||
| prim->set_axis(1); | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| CaffeNodeRegistrar g_caffeScaleParser("Scale", new CaffeScaleParser()); | |||
| @@ -29,8 +29,6 @@ class CaffeScaleParser : public CaffeNodeParser { | |||
| ~CaffeScaleParser() override = default; | |||
| ops::PrimitiveC *Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) override; | |||
| static STATUS GetAxisIndex(const int32_t &axis, uint32_t *axis_index); | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -21,16 +21,12 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *CaffeSliceParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) { | |||
| auto primitive_c = new (std::nothrow) ops::Split(); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Split failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Split>(); | |||
| const caffe::SliceParameter &slice_param = proto.slice_param(); | |||
| primitive_c->set_output_num(2); | |||
| prim->set_output_num(2); | |||
| if (!slice_param.slice_point().empty()) { | |||
| primitive_c->set_output_num(slice_param.slice_point_size() + 1); | |||
| prim->set_output_num(slice_param.slice_point_size() + 1); | |||
| std::vector<int64_t> size_splits; | |||
| for (int i = 0; i < slice_param.slice_point_size(); ++i) { | |||
| if (i == 0) { | |||
| @@ -40,16 +36,16 @@ ops::PrimitiveC *CaffeSliceParser::Parse(const caffe::LayerParameter &proto, con | |||
| } | |||
| } | |||
| size_splits.push_back(-1); | |||
| primitive_c->set_size_splits(size_splits); | |||
| prim->set_size_splits(size_splits); | |||
| } | |||
| if (slice_param.has_axis()) { | |||
| primitive_c->set_axis(slice_param.axis()); | |||
| prim->set_axis(slice_param.axis()); | |||
| } else if (slice_param.has_slice_dim()) { | |||
| primitive_c->set_axis(slice_param.slice_dim()); | |||
| prim->set_axis(slice_param.slice_dim()); | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| CaffeNodeRegistrar g_caffeSliceParser("Slice", new CaffeSliceParser()); | |||
| @@ -21,22 +21,18 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *CaffeSoftmaxParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) { | |||
| auto primitive_c = new (std::nothrow) ops::Softmax(); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Softmax failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Softmax>(); | |||
| if (proto.has_softmax_param() && proto.softmax_param().has_axis()) { | |||
| if (proto.softmax_param().axis() == -1) { | |||
| MS_LOG(DEBUG) << "axis with -1 may lead to calculation errors when input less than 4 dims."; | |||
| } | |||
| primitive_c->set_axis({proto.softmax_param().axis()}); | |||
| prim->set_axis({proto.softmax_param().axis()}); | |||
| } else { | |||
| primitive_c->set_axis({1}); | |||
| prim->set_axis({1}); | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| CaffeNodeRegistrar g_caffeSoftmaxParser("Softmax", new CaffeSoftmaxParser()); | |||
| @@ -22,11 +22,8 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *CaffeTileParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) { | |||
| auto primitive_c = new (std::nothrow) ops::TileFusion(); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new TileFusion failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::TileFusion>(); | |||
| const caffe::TileParameter &tile_param = proto.tile_param(); | |||
| std::vector<int64_t> dims; | |||
| dims.clear(); | |||
| @@ -35,7 +32,7 @@ ops::PrimitiveC *CaffeTileParser::Parse(const caffe::LayerParameter &proto, cons | |||
| } else { | |||
| dims.push_back(1); | |||
| } | |||
| primitive_c->set_dims(dims); | |||
| prim->set_dims(dims); | |||
| std::vector<int32_t> multiples; | |||
| multiples.clear(); | |||
| @@ -44,9 +41,9 @@ ops::PrimitiveC *CaffeTileParser::Parse(const caffe::LayerParameter &proto, cons | |||
| } else { | |||
| multiples.push_back(1); | |||
| } | |||
| primitive_c->AddAttr("multiples", MakeValue(multiples)); | |||
| prim->AddAttr("multiples", MakeValue(multiples)); | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| CaffeNodeRegistrar g_caffeTileParser("Tile", new CaffeTileParser()); | |||
| @@ -25,42 +25,30 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Activation; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new ReLU failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Activation>(); | |||
| primitive_c->set_activation_type(mindspore::ActivationType::RELU); | |||
| prim->set_activation_type(mindspore::ActivationType::RELU); | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| ops::PrimitiveC *OnnxLeakyReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Activation; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new LeakyRelu failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Activation>(); | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "alpha") { | |||
| primitive_c->set_alpha(onnx_node_attr.f()); | |||
| prim->set_alpha(onnx_node_attr.f()); | |||
| } | |||
| } | |||
| primitive_c->set_activation_type(mindspore::ActivationType::LEAKY_RELU); | |||
| prim->set_activation_type(mindspore::ActivationType::LEAKY_RELU); | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| ops::PrimitiveC *OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::PReLUFusion; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new PReLU failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::PReLUFusion>(); | |||
| std::vector<onnx::TensorProto> params; | |||
| const auto &input_name = onnx_node.input(1); | |||
| @@ -82,10 +70,10 @@ ops::PrimitiveC *OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, cons | |||
| const auto slope_raw_data = reinterpret_cast<const float *>(slope_data->raw_data().data()); | |||
| const int64_t slope_size = slope_data->raw_data().size() / sizeof(float); | |||
| std::vector<float> slope; | |||
| bool channelShared = false; | |||
| bool channel_shared = false; | |||
| if (slope_size == 1) { | |||
| slope.push_back(*slope_raw_data); | |||
| channelShared = true; | |||
| channel_shared = true; | |||
| } else { | |||
| slope.resize(slope_size); | |||
| if (memcpy_s(slope.data(), slope_size * sizeof(float), slope_raw_data, slope_size * sizeof(float)) != EOK) { | |||
| @@ -93,54 +81,42 @@ ops::PrimitiveC *OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, cons | |||
| return nullptr; | |||
| } | |||
| } | |||
| primitive_c->set_slope(slope); | |||
| primitive_c->set_channel_shared(channelShared); | |||
| prim->set_slope(slope); | |||
| prim->set_channel_shared(channel_shared); | |||
| } else { | |||
| MS_LOG(WARNING) << "The slope pf prelu is null, which may cause errors."; | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| ops::PrimitiveC *OnnxEluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Elu; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Elu failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Elu>(); | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "alpha") { | |||
| primitive_c->set_alpha(onnx_node_attr.f()); | |||
| prim->set_alpha(onnx_node_attr.f()); | |||
| } | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| ops::PrimitiveC *OnnxTanhParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Activation; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Tanh failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Activation>(); | |||
| primitive_c->set_activation_type(mindspore::ActivationType::TANH); | |||
| prim->set_activation_type(mindspore::ActivationType::TANH); | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| ops::PrimitiveC *OnnxSigmoidParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Activation; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Sigmoid failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Activation>(); | |||
| primitive_c->set_activation_type(mindspore::ActivationType::SIGMOID); | |||
| prim->set_activation_type(mindspore::ActivationType::SIGMOID); | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxReluParser("Relu", new OnnxReluParser()); | |||
| @@ -21,13 +21,8 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxAdderParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::AdderFusion; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new AdderFusion failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::AdderFusion>(); | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxAdderParser("adder_f", new OnnxAdderParser()); | |||
| @@ -21,22 +21,18 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxArgMaxParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::ArgMaxFusion; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new ArgMax failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::ArgMaxFusion>(); | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "axis") { | |||
| primitive_c->set_axis(onnx_node_attr.i()); | |||
| prim->set_axis(onnx_node_attr.i()); | |||
| } else if (attribute_name == "keepdims") { | |||
| primitive_c->set_keep_dims(static_cast<bool>(onnx_node_attr.i())); | |||
| prim->set_keep_dims(static_cast<bool>(onnx_node_attr.i())); | |||
| } | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxArgMaxParser("ArgMax", new OnnxArgMaxParser()); | |||
| @@ -50,297 +50,160 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxAddParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::AddFusion; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new AddFusion failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::AddFusion>(); | |||
| return prim.release(); | |||
| } | |||
| ops::PrimitiveC *OnnxSubParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::SubFusion; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new SubFusion failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::SubFusion>(); | |||
| return prim.release(); | |||
| } | |||
| ops::PrimitiveC *OnnxDivParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::DivFusion; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new DivFusion failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::DivFusion>(); | |||
| return prim.release(); | |||
| } | |||
| ops::PrimitiveC *OnnxMulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::MulFusion; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new MulFusion failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::MulFusion>(); | |||
| return prim.release(); | |||
| } | |||
| ops::PrimitiveC *OnnxEqualParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Equal; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Equal failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::Equal>(); | |||
| return prim.release(); | |||
| } | |||
| ops::PrimitiveC *OnnxLessParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Less; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Less failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::Less>(); | |||
| return prim.release(); | |||
| } | |||
| ops::PrimitiveC *OnnxGreaterParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Greater; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Greater failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::Greater>(); | |||
| return prim.release(); | |||
| } | |||
| ops::PrimitiveC *OnnxFloorParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Floor; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Floor failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::Floor>(); | |||
| return prim.release(); | |||
| } | |||
| ops::PrimitiveC *OnnxAbsParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Abs; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Abs failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::Abs>(); | |||
| return prim.release(); | |||
| } | |||
| ops::PrimitiveC *OnnxExpParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::ExpFusion; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new ExpFusion failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::ExpFusion>(); | |||
| primitive_c->set_base(-1.0); | |||
| primitive_c->set_scale(1.0); | |||
| primitive_c->set_shift(0.0); | |||
| prim->set_base(-1.0); | |||
| prim->set_scale(1.0); | |||
| prim->set_shift(0.0); | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| ops::PrimitiveC *OnnxCosParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Cos; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Cos failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::Cos>(); | |||
| return prim.release(); | |||
| } | |||
| ops::PrimitiveC *OnnxCeilParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Ceil; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Ceil failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::Ceil>(); | |||
| return prim.release(); | |||
| } | |||
| ops::PrimitiveC *OnnxLogParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Log; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Log failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::Log>(); | |||
| return prim.release(); | |||
| } | |||
| ops::PrimitiveC *OnnxAtanParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Atan; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Atan failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::Atan>(); | |||
| return prim.release(); | |||
| } | |||
| ops::PrimitiveC *OnnxAsinParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Asin; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Asin failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::Asin>(); | |||
| return prim.release(); | |||
| } | |||
| ops::PrimitiveC *OnnxAndParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::LogicalAnd; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new LogicalAnd failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::LogicalAnd>(); | |||
| return prim.release(); | |||
| } | |||
| ops::PrimitiveC *OnnxOrParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::LogicalOr; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new LogicalOr failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::LogicalOr>(); | |||
| return prim.release(); | |||
| } | |||
| ops::PrimitiveC *OnnxNotParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::LogicalNot; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new LogicalNot failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::LogicalNot>(); | |||
| return prim.release(); | |||
| } | |||
| ops::PrimitiveC *OnnxNegParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Neg; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Neg failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::Neg>(); | |||
| return prim.release(); | |||
| } | |||
| ops::PrimitiveC *OnnxRoundParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Round; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Round failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::Round>(); | |||
| return prim.release(); | |||
| } | |||
| ops::PrimitiveC *OnnxSinParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Sin; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new sin failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::Sin>(); | |||
| return prim.release(); | |||
| } | |||
| ops::PrimitiveC *OnnxTanParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Tan; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Tan failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::Tan>(); | |||
| return prim.release(); | |||
| } | |||
| ops::PrimitiveC *OnnxSqrtParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Sqrt; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Sqrt failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::Sqrt>(); | |||
| return prim.release(); | |||
| } | |||
| ops::PrimitiveC *OnnxPowParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::PowFusion; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new PowFusion failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::PowFusion>(); | |||
| primitive_c->set_scale(1.0); | |||
| primitive_c->set_shift(0.0); | |||
| prim->set_scale(1.0); | |||
| prim->set_shift(0.0); | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| ops::PrimitiveC *OnnxMinParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Minimum; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Minimum failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::Minimum>(); | |||
| return prim.release(); | |||
| } | |||
| ops::PrimitiveC *OnnxMaxParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Maximum; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Maximum failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::Maximum>(); | |||
| return prim.release(); | |||
| } | |||
| ops::PrimitiveC *OnnxEltwiseParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Eltwise; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Eltwise failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Eltwise>(); | |||
| if (onnx_node.op_type() == "Sum") { | |||
| primitive_c->set_mode(mindspore::EltwiseMode::SUM); | |||
| prim->set_mode(mindspore::EltwiseMode::SUM); | |||
| } else { | |||
| MS_LOG(ERROR) << "unsupported Eltwise type"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| ops::PrimitiveC *OnnxReciprocalParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Reciprocal; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Reciprocal failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::Reciprocal>(); | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxAddParser("Add", new OnnxAddParser()); | |||
| @@ -21,21 +21,17 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxBatchNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::FusedBatchNorm; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new FusedBatchNorm failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::FusedBatchNorm>(); | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| if (onnx_node_attr.name() == "epsilon") { | |||
| primitive_c->set_epsilon(onnx_node_attr.f()); | |||
| prim->set_epsilon(onnx_node_attr.f()); | |||
| } else if (onnx_node_attr.name() == "momentum") { | |||
| primitive_c->set_momentum(onnx_node_attr.f()); | |||
| prim->set_momentum(onnx_node_attr.f()); | |||
| } | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxBatchNormParser("BatchNormalization", new OnnxBatchNormParser()); | |||
| @@ -21,13 +21,8 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxBiasAddParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::BiasAdd; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new BiasAdd failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::BiasAdd>(); | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxBiasAddParser("BiasAdd", new OnnxBiasAddParser()); | |||
| @@ -23,11 +23,7 @@ namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Cast; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Cast failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Cast>(); | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| @@ -36,11 +32,11 @@ ops::PrimitiveC *OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const | |||
| if (dst_type == kNumberTypeInt64) { | |||
| dst_type = kNumberTypeInt32; | |||
| } | |||
| primitive_c->AddAttr("to", MakeValue(static_cast<int32_t>(dst_type))); | |||
| prim->AddAttr("to", MakeValue(static_cast<int32_t>(dst_type))); | |||
| } | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxCastParser("Cast", new OnnxCastParser()); | |||
| @@ -21,24 +21,20 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Clip; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Clip failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Clip>(); | |||
| primitive_c->set_min(-1); | |||
| primitive_c->set_max(-1); | |||
| prim->set_min(-1); | |||
| prim->set_max(-1); | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "max") { | |||
| primitive_c->set_max(onnx_node_attr.f()); | |||
| prim->set_max(onnx_node_attr.f()); | |||
| } else if (attribute_name == "min") { | |||
| primitive_c->set_min(onnx_node_attr.f()); | |||
| prim->set_min(onnx_node_attr.f()); | |||
| } | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxClipParser("Clip", new OnnxClipParser()); | |||
| @@ -21,20 +21,16 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxConcatParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Concat; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Concat failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Concat>(); | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "axis") { | |||
| primitive_c->set_axis(onnx_node_attr.i()); | |||
| prim->set_axis(onnx_node_attr.i()); | |||
| } | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxConcatParser("Concat", new OnnxConcatParser()); | |||
| @@ -24,11 +24,7 @@ namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxConstantOfShapeParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::ConstantOfShape; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new ConstantOfShape failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::ConstantOfShape>(); | |||
| int data_type = 0; | |||
| std::vector<float> values; | |||
| @@ -61,10 +57,10 @@ ops::PrimitiveC *OnnxConstantOfShapeParser::Parse(const onnx::GraphProto &onnx_g | |||
| if (values.empty()) { | |||
| values = {0}; | |||
| } | |||
| primitive_c->set_value(values); | |||
| primitive_c->set_data_type((int64_t)data_type); | |||
| prim->set_value(values); | |||
| prim->set_data_type((int64_t)data_type); | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxConstantOfShapeParser("ConstantOfShape", new OnnxConstantOfShapeParser()); | |||
| @@ -24,7 +24,7 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxConstantParser::AddDataInfoAttr(const onnx::TensorProto &onnx_const_tensor, ops::PrimitiveC *primitive_c) { | |||
| STATUS OnnxConstantParser::AddDataInfoAttr(const onnx::TensorProto &onnx_const_tensor, ops::PrimitiveC *prim) { | |||
| ParamValueLitePtr param_value = std::make_shared<ParamValueLite>(); | |||
| if (param_value == nullptr) { | |||
| MS_LOG(ERROR) << "new a paramValueLite failed."; | |||
| @@ -48,16 +48,12 @@ STATUS OnnxConstantParser::AddDataInfoAttr(const onnx::TensorProto &onnx_const_t | |||
| MS_LOG(ERROR) << "get value failed."; | |||
| return RET_ERROR; | |||
| } | |||
| primitive_c->set_attr("const_data", param_value); | |||
| prim->set_attr("const_data", param_value); | |||
| return RET_OK; | |||
| } | |||
| ops::PrimitiveC *OnnxConstantParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Constant; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Constant failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Constant>(); | |||
| for (const auto &attr : onnx_node.attribute()) { | |||
| if (attr.name() == "sparse_value") { | |||
| @@ -66,18 +62,16 @@ ops::PrimitiveC *OnnxConstantParser::Parse(const onnx::GraphProto &onnx_graph, c | |||
| } | |||
| if (attr.name() == "value") { | |||
| const auto &const_tensor = attr.t(); | |||
| if (AddDataInfoAttr(const_tensor, primitive_c) != RET_OK) { | |||
| if (AddDataInfoAttr(const_tensor, prim.get()) != RET_OK) { | |||
| MS_LOG(ERROR) << "add basic attr failed."; | |||
| delete primitive_c; | |||
| return nullptr; | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << "processing Constant op attr " << attr.name() << " not implemented"; | |||
| delete primitive_c; | |||
| return nullptr; | |||
| } | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxConstantParser("Constant", new OnnxConstantParser()); | |||
| @@ -29,7 +29,7 @@ class OnnxConstantParser : public OnnxNodeParser { | |||
| ops::PrimitiveC *Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| STATUS AddDataInfoAttr(const onnx::TensorProto &onnx_const_tensor, ops::PrimitiveC *primitive_c); | |||
| STATUS AddDataInfoAttr(const onnx::TensorProto &onnx_const_tensor, ops::PrimitiveC *prim); | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -20,22 +20,17 @@ | |||
| #include <vector> | |||
| #include <string> | |||
| #include "ops/fusion/conv2d_fusion.h" | |||
| #include "ops/fusion/depthwise_conv2d_fusion.h" | |||
| namespace mindspore::lite { | |||
| ops::PrimitiveC *OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Conv2DFusion; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Conv2DFusion failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Conv2DFusion>(); | |||
| primitive_c->set_pad({0, 0, 0, 0}); | |||
| prim->set_pad({0, 0, 0, 0}); | |||
| mindspore::Format format = mindspore::Format::NCHW; | |||
| mindspore::PadMode padMode = mindspore::PadMode::PAD; | |||
| mindspore::PadMode pad_mode = mindspore::PadMode::PAD; | |||
| int64_t channelOut = 1; | |||
| int64_t channelIn = 1; | |||
| int64_t channel_out = 1; | |||
| int64_t channel_in = 1; | |||
| int64_t group = 1; | |||
| std::vector<int64_t> kernels; | |||
| std::vector<int64_t> strides; | |||
| @@ -59,7 +54,7 @@ ops::PrimitiveC *OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const | |||
| } | |||
| kernels.push_back(onnx_node_attr.ints(0)); | |||
| kernels.push_back(onnx_node_attr.ints(1)); | |||
| primitive_c->set_kernel_size(kernels); | |||
| prim->set_kernel_size(kernels); | |||
| } else if (onnx_node_attr.name() == "kernel_shape") { | |||
| if (onnx_node_attr.ints().size() != 2) { | |||
| MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2"; | |||
| @@ -67,14 +62,14 @@ ops::PrimitiveC *OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const | |||
| } | |||
| kernels.push_back(onnx_node_attr.ints(0)); | |||
| kernels.push_back(onnx_node_attr.ints(1)); | |||
| primitive_c->set_kernel_size(kernels); | |||
| prim->set_kernel_size(kernels); | |||
| } else if (onnx_node_attr.name() == "auto_pad") { | |||
| if (onnx_node_attr.s() == "SAME_UPPER") { | |||
| padMode = mindspore::PadMode::SAME; | |||
| pad_mode = mindspore::PadMode::SAME; | |||
| } else if (onnx_node_attr.s() == "VALID") { | |||
| padMode = mindspore::PadMode::VALID; | |||
| pad_mode = mindspore::PadMode::VALID; | |||
| } else if (onnx_node_attr.s() == "NOTSET") { | |||
| padMode = mindspore::PadMode::PAD; | |||
| pad_mode = mindspore::PadMode::PAD; | |||
| } else if (onnx_node_attr.s() == "SAME_LOWER") { | |||
| MS_LOG(ERROR) << "unsupported padMode"; | |||
| return nullptr; | |||
| @@ -88,7 +83,7 @@ ops::PrimitiveC *OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const | |||
| pads.push_back(onnx_node_attr.ints(2)); | |||
| pads.push_back(onnx_node_attr.ints(1)); | |||
| pads.push_back(onnx_node_attr.ints(3)); | |||
| primitive_c->set_pad_list(pads); | |||
| prim->set_pad_list(pads); | |||
| } else if (onnx_node_attr.name() == "strides") { | |||
| if (onnx_node_attr.ints().size() != 2) { | |||
| MS_LOG(ERROR) << "strides size " << onnx_node_attr.ints().size() << " is not 2"; | |||
| @@ -96,7 +91,7 @@ ops::PrimitiveC *OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const | |||
| } | |||
| strides.push_back(onnx_node_attr.ints(0)); | |||
| strides.push_back(onnx_node_attr.ints(1)); | |||
| primitive_c->set_stride(strides); | |||
| prim->set_stride(strides); | |||
| } else if (onnx_node_attr.name() == "order") { | |||
| if (onnx_node_attr.s() == "NHWC") { | |||
| format = mindspore::Format::NHWC; | |||
| @@ -109,18 +104,18 @@ ops::PrimitiveC *OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const | |||
| if (dilation.empty()) { | |||
| dilation = {1, 1}; | |||
| } | |||
| primitive_c->set_dilation(dilation); | |||
| prim->set_dilation(dilation); | |||
| if (pads.empty()) { | |||
| pads = {0, 0, 0, 0}; | |||
| } | |||
| primitive_c->set_pad_list(pads); | |||
| prim->set_pad_list(pads); | |||
| primitive_c->set_format(format); | |||
| primitive_c->set_pad_mode(padMode); | |||
| primitive_c->set_group(group); | |||
| prim->set_format(format); | |||
| prim->set_pad_mode(pad_mode); | |||
| prim->set_group(group); | |||
| // get channelOut and channelIn | |||
| // get channel_out and channel_in | |||
| const auto &onnx_conv_weight = onnx_node.input(1); | |||
| if (onnx_node.op_type() == "Conv") { | |||
| auto node_iter = | |||
| @@ -135,8 +130,8 @@ ops::PrimitiveC *OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const | |||
| for (int i = 0; i < size; ++i) { | |||
| weight_shape.emplace_back((*node_iter).dims(i)); | |||
| } | |||
| channelOut = weight_shape[0]; | |||
| channelIn = weight_shape[1] * group; | |||
| channel_out = weight_shape[0]; | |||
| channel_in = weight_shape[1] * group; | |||
| } | |||
| } else { | |||
| auto node_iter = | |||
| @@ -156,23 +151,24 @@ ops::PrimitiveC *OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const | |||
| } | |||
| dims.insert(dims.begin(), iter->ints().begin(), iter->ints().end()); | |||
| } | |||
| channelOut = dims.at(0); | |||
| channelIn = dims.at(3) * group; | |||
| channel_out = dims.at(0); | |||
| channel_in = dims.at(3) * group; | |||
| } | |||
| primitive_c->set_in_channel(channelIn); | |||
| primitive_c->set_out_channel(channelOut); | |||
| prim->set_in_channel(channel_in); | |||
| prim->set_out_channel(channel_out); | |||
| // parse activationType | |||
| if (onnx_node.op_type() == "ConvRelu" || onnx_node.op_type() == "Int8ConvRelu") { | |||
| primitive_c->set_activation_type(mindspore::ActivationType::RELU); | |||
| prim->set_activation_type(mindspore::ActivationType::RELU); | |||
| } else { | |||
| primitive_c->set_activation_type(mindspore::ActivationType::NO_ACTIVATION); | |||
| prim->set_activation_type(mindspore::ActivationType::NO_ACTIVATION); | |||
| } | |||
| if (group == channelIn && channelIn == channelOut) { | |||
| primitive_c->AddAttr(ops::kIsDepthWise, MakeValue<bool>(true)); | |||
| if (group == channel_in && channel_in == channel_out) { | |||
| prim->AddAttr(ops::kIsDepthWise, MakeValue<bool>(true)); | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxConvParser("Conv", new OnnxConvParser()); | |||
| @@ -23,15 +23,12 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Conv2dTransposeFusion; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Conv2dTransposeFusion failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Conv2dTransposeFusion>(); | |||
| primitive_c->set_pad({0, 0, 0, 0}); | |||
| prim->set_pad({0, 0, 0, 0}); | |||
| mindspore::Format format = mindspore::Format::NCHW; | |||
| mindspore::PadMode padMode = mindspore::PadMode::PAD; | |||
| mindspore::PadMode pad_mode = mindspore::PadMode::PAD; | |||
| int64_t group = 1; | |||
| std::vector<int64_t> kernel; | |||
| std::vector<int64_t> dilate; | |||
| @@ -47,7 +44,7 @@ ops::PrimitiveC *OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, con | |||
| } | |||
| dilate.push_back(onnx_node_attr.ints(0)); | |||
| dilate.push_back(onnx_node_attr.ints(1)); | |||
| primitive_c->set_dilation(dilate); | |||
| prim->set_dilation(dilate); | |||
| } else if (onnx_node_attr.name() == "kernels") { | |||
| if (onnx_node_attr.ints().size() != 2) { | |||
| MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2"; | |||
| @@ -55,7 +52,7 @@ ops::PrimitiveC *OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, con | |||
| } | |||
| kernel.push_back(onnx_node_attr.ints(0)); | |||
| kernel.push_back(onnx_node_attr.ints(1)); | |||
| primitive_c->set_kernel_size(kernel); | |||
| prim->set_kernel_size(kernel); | |||
| } else if (onnx_node_attr.name() == "kernel_shape") { | |||
| if (onnx_node_attr.ints().size() != 2) { | |||
| MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2"; | |||
| @@ -63,9 +60,9 @@ ops::PrimitiveC *OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, con | |||
| } | |||
| kernel.push_back(onnx_node_attr.ints(0)); | |||
| kernel.push_back(onnx_node_attr.ints(1)); | |||
| primitive_c->set_kernel_size(kernel); | |||
| prim->set_kernel_size(kernel); | |||
| } else if (onnx_node_attr.name() == "auto_pad") { | |||
| padMode = GetOnnxPadMode(onnx_node_attr); | |||
| pad_mode = GetOnnxPadMode(onnx_node_attr); | |||
| } else if (onnx_node_attr.name() == "pads") { | |||
| if (onnx_node_attr.ints().size() != 4) { | |||
| MS_LOG(ERROR) << "pads size " << onnx_node_attr.ints().size() << " is not 4"; | |||
| @@ -75,7 +72,7 @@ ops::PrimitiveC *OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, con | |||
| pads.push_back(onnx_node_attr.ints(2)); | |||
| pads.push_back(onnx_node_attr.ints(1)); | |||
| pads.push_back(onnx_node_attr.ints(3)); | |||
| primitive_c->set_pad_list(pads); | |||
| prim->set_pad_list(pads); | |||
| } else if (onnx_node_attr.name() == "strides") { | |||
| if (onnx_node_attr.ints().size() != 2) { | |||
| MS_LOG(ERROR) << "strides size " << onnx_node_attr.ints().size() << " is not 2"; | |||
| @@ -83,7 +80,7 @@ ops::PrimitiveC *OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, con | |||
| } | |||
| stride.push_back(onnx_node_attr.ints(0)); | |||
| stride.push_back(onnx_node_attr.ints(1)); | |||
| primitive_c->set_stride(stride); | |||
| prim->set_stride(stride); | |||
| } else if (onnx_node_attr.name() == "order") { | |||
| if (onnx_node_attr.s() == "NHWC") { | |||
| format = mindspore::Format::NHWC; | |||
| @@ -96,9 +93,9 @@ ops::PrimitiveC *OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, con | |||
| return nullptr; | |||
| } | |||
| } | |||
| primitive_c->set_format(format); | |||
| primitive_c->set_group(group); | |||
| primitive_c->set_pad_mode(padMode); | |||
| prim->set_format(format); | |||
| prim->set_group(group); | |||
| prim->set_pad_mode(pad_mode); | |||
| const auto &onnx_conv_weight = onnx_node.input(1); | |||
| auto node_iter = | |||
| @@ -118,12 +115,14 @@ ops::PrimitiveC *OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, con | |||
| MS_LOG(ERROR) << "weight_shape.size() should be 4, but is " << weight_shape.size(); | |||
| return nullptr; | |||
| } | |||
| primitive_c->set_in_channel(weight_shape[0]); | |||
| primitive_c->set_out_channel(weight_shape[1] * group); | |||
| prim->set_in_channel(weight_shape[0]); | |||
| prim->set_out_channel(weight_shape[1] * group); | |||
| if (group != 1 && weight_shape[1] == 1) { | |||
| primitive_c->AddAttr(ops::kIsDepthWise, MakeValue<bool>(true)); | |||
| prim->AddAttr(ops::kIsDepthWise, MakeValue<bool>(true)); | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxDeConvParser("ConvTranspose", new OnnxDeConvParser()); | |||
| @@ -21,20 +21,16 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxDepthToSpaceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::DepthToSpace; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new DepthToSpace failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::DepthToSpace>(); | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "blocksize") { | |||
| primitive_c->set_block_size(onnx_node_attr.i()); | |||
| prim->set_block_size(onnx_node_attr.i()); | |||
| } | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxDepthToSpaceParser("DepthToSpace", new OnnxDepthToSpaceParser()); | |||
| @@ -21,20 +21,16 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxDropoutParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Dropout; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Dropout failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Dropout>(); | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "ratio") { | |||
| primitive_c->set_keep_prob(onnx_node_attr.f()); | |||
| prim->set_keep_prob(onnx_node_attr.f()); | |||
| } | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxDropoutParser("Dropout", new OnnxDropoutParser()); | |||
| @@ -22,11 +22,7 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::BroadcastTo; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new BroadcastTo failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::BroadcastTo>(); | |||
| std::vector<int64_t> dst_shape; | |||
| const auto &onnx_expand_power = onnx_node.input(1); | |||
| @@ -46,9 +42,9 @@ ops::PrimitiveC *OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, con | |||
| } | |||
| } | |||
| } | |||
| primitive_c->set_shape(dst_shape); | |||
| prim->set_shape(dst_shape); | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxExpandSpaceParser("Expand", new OnnxExpandParser()); | |||
| @@ -21,13 +21,8 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxFlattenParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Flatten; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Flatten failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::Flatten>(); | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxFlattenParser("Flatten", new OnnxFlattenParser()); | |||
| @@ -21,11 +21,7 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxGatherParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Gather; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Gather failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Gather>(); | |||
| int32_t axis = 0; | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| @@ -34,9 +30,9 @@ ops::PrimitiveC *OnnxGatherParser::Parse(const onnx::GraphProto &onnx_graph, con | |||
| axis = static_cast<int32_t>(onnx_node_attr.i()); | |||
| } | |||
| } | |||
| primitive_c->AddAttr("axis", MakeValue(axis)); | |||
| prim->AddAttr("axis", MakeValue(axis)); | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxGatherParser("Gather", new OnnxGatherParser()); | |||
| @@ -22,11 +22,7 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxGemmParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::MakeTuple; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new MakeTuple failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::MakeTuple>(); | |||
| auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser("MatMul"); | |||
| if (node_parser == nullptr) { | |||
| @@ -34,7 +30,7 @@ ops::PrimitiveC *OnnxGemmParser::Parse(const onnx::GraphProto &onnx_graph, const | |||
| return nullptr; | |||
| } | |||
| auto *matmul_primitive = node_parser->Parse(onnx_graph, onnx_node); | |||
| primitive_c->AddAttr("MatMul", std::shared_ptr<ops::PrimitiveC>(matmul_primitive)); | |||
| prim->AddAttr("MatMul", std::shared_ptr<ops::PrimitiveC>(matmul_primitive)); | |||
| node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser("BiasAdd"); | |||
| if (node_parser == nullptr) { | |||
| @@ -42,9 +38,9 @@ ops::PrimitiveC *OnnxGemmParser::Parse(const onnx::GraphProto &onnx_graph, const | |||
| return nullptr; | |||
| } | |||
| auto *bias_add_primitive = node_parser->Parse(onnx_graph, onnx_node); | |||
| primitive_c->AddAttr("BiasAdd", std::shared_ptr<ops::PrimitiveC>(bias_add_primitive)); | |||
| prim->AddAttr("BiasAdd", std::shared_ptr<ops::PrimitiveC>(bias_add_primitive)); | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxGemmParser("Gemm", new OnnxGemmParser()); | |||
| @@ -24,14 +24,10 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxGivenTensorFillParser::ParseInt8GivenIntTensorFill(const onnx::NodeProto &onnx_node, | |||
| ops::PrimitiveC *primitive_c, | |||
| STATUS OnnxGivenTensorFillParser::ParseInt8GivenIntTensorFill(const onnx::NodeProto &onnx_node, ops::PrimitiveC *prim, | |||
| const std::vector<int> &shape) { | |||
| ParamValueLitePtr param_value = std::make_shared<ParamValueLite>(); | |||
| if (param_value == nullptr) { | |||
| MS_LOG(ERROR) << "new a paramValueLite failed."; | |||
| return RET_ERROR; | |||
| } | |||
| int data_count = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()); | |||
| auto iter = std::find_if(onnx_node.attribute().begin(), onnx_node.attribute().end(), | |||
| [](const onnx::AttributeProto &attr) { return attr.name() == "values"; }); | |||
| @@ -46,6 +42,7 @@ STATUS OnnxGivenTensorFillParser::ParseInt8GivenIntTensorFill(const onnx::NodePr | |||
| } | |||
| if (iter->ints().data() == nullptr) { | |||
| MS_LOG(ERROR) << "origin ints data in onnx is nullptr"; | |||
| delete[] param_data; | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (memcpy_s(param_data, data_size, iter->ints().data(), data_size) != EOK) { | |||
| @@ -57,18 +54,14 @@ STATUS OnnxGivenTensorFillParser::ParseInt8GivenIntTensorFill(const onnx::NodePr | |||
| param_value->set_format(schema::Format_NUM_OF_FORMAT); | |||
| param_value->set_tensor_type(kNumberTypeInt64); | |||
| param_value->SetTensorData(param_data, data_size); | |||
| primitive_c->set_attr("const_data", param_value); | |||
| prim->set_attr("const_data", param_value); | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxGivenTensorFillParser::ParseInt8GivenTensorFill(const onnx::NodeProto &onnx_node, | |||
| ops::PrimitiveC *primitive_c, | |||
| STATUS OnnxGivenTensorFillParser::ParseInt8GivenTensorFill(const onnx::NodeProto &onnx_node, ops::PrimitiveC *prim, | |||
| const std::vector<int> &shape) { | |||
| ParamValueLitePtr param_value = std::make_shared<ParamValueLite>(); | |||
| if (param_value == nullptr) { | |||
| MS_LOG(ERROR) << "new a paramValueLite failed."; | |||
| return RET_ERROR; | |||
| } | |||
| int data_count = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>()); | |||
| auto iter = std::find_if(onnx_node.attribute().begin(), onnx_node.attribute().end(), | |||
| [](const onnx::AttributeProto &attr) { return attr.name() == "values"; }); | |||
| @@ -89,16 +82,12 @@ STATUS OnnxGivenTensorFillParser::ParseInt8GivenTensorFill(const onnx::NodeProto | |||
| param_value->set_format(schema::Format_NUM_OF_FORMAT); | |||
| param_value->set_tensor_type(kNumberTypeUInt8); | |||
| param_value->SetTensorData(param_data, data_count); | |||
| primitive_c->set_attr("const_data", param_value); | |||
| prim->set_attr("const_data", param_value); | |||
| return RET_OK; | |||
| } | |||
| ops::PrimitiveC *OnnxGivenTensorFillParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Constant; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Constant failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Constant>(); | |||
| std::vector<int64_t> shape_vector; | |||
| auto iter = std::find_if(onnx_node.attribute().begin(), onnx_node.attribute().end(), | |||
| @@ -110,18 +99,18 @@ ops::PrimitiveC *OnnxGivenTensorFillParser::Parse(const onnx::GraphProto &onnx_g | |||
| std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape), | |||
| [](const int64_t &val) { return static_cast<int32_t>(val); }); | |||
| if (onnx_node.op_type() == "Int8GivenIntTensorFill") { | |||
| if (ParseInt8GivenIntTensorFill(onnx_node, primitive_c, shape) != RET_OK) { | |||
| if (ParseInt8GivenIntTensorFill(onnx_node, prim.get(), shape) != RET_OK) { | |||
| MS_LOG(ERROR) << "given tensor fill parse failed."; | |||
| return nullptr; | |||
| } | |||
| } else if (onnx_node.op_type() == "Int8GivenTensorFill") { | |||
| if (ParseInt8GivenTensorFill(onnx_node, primitive_c, shape) != RET_OK) { | |||
| if (ParseInt8GivenTensorFill(onnx_node, prim.get(), shape) != RET_OK) { | |||
| MS_LOG(ERROR) << "given tensor fill parse failed."; | |||
| return nullptr; | |||
| } | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxInt8GivenIntTensorFillParser("Int8GivenIntTensorFill", new OnnxGivenTensorFillParser()); | |||
| @@ -30,9 +30,9 @@ class OnnxGivenTensorFillParser : public OnnxNodeParser { | |||
| ops::PrimitiveC *Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| STATUS ParseInt8GivenIntTensorFill(const onnx::NodeProto &onnx_node, ops::PrimitiveC *primitive_c, | |||
| STATUS ParseInt8GivenIntTensorFill(const onnx::NodeProto &onnx_node, ops::PrimitiveC *prim, | |||
| const std::vector<int> &shape); | |||
| STATUS ParseInt8GivenTensorFill(const onnx::NodeProto &onnx_node, ops::PrimitiveC *primitive_c, | |||
| STATUS ParseInt8GivenTensorFill(const onnx::NodeProto &onnx_node, ops::PrimitiveC *prim, | |||
| const std::vector<int> &shape); | |||
| }; | |||
| } // namespace lite | |||
| @@ -21,14 +21,10 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxIdentityParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Identity; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Identity failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::Identity>(); | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxIdentityParser("Identity", new OnnxIdentityParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -21,22 +21,18 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxInstanceNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::LayerNormFusion; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new LayerNormFusion failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::LayerNormFusion>(); | |||
| primitive_c->set_elementwise_affine(true); | |||
| prim->set_elementwise_affine(true); | |||
| if (!onnx_node.attribute().empty()) { | |||
| auto onnx_node_attr = onnx_node.attribute().at(0); | |||
| if (onnx_node_attr.name() == "epsilon") { | |||
| primitive_c->set_epsilon(onnx_node_attr.f()); | |||
| prim->set_epsilon(onnx_node_attr.f()); | |||
| } | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxInstanceNormParser("InstanceNormalization", new OnnxInstanceNormParser()); | |||
| @@ -21,22 +21,18 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxLpNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::LpNormalization; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new LpNormalization failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::LpNormalization>(); | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "axis") { | |||
| primitive_c->set_axis(onnx_node_attr.i()); | |||
| prim->set_axis(onnx_node_attr.i()); | |||
| } else if (attribute_name == "p") { | |||
| primitive_c->set_p(onnx_node_attr.i()); | |||
| prim->set_p(onnx_node_attr.i()); | |||
| } | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxLpNormParser("LpNormalization", new OnnxLpNormParser()); | |||
| @@ -21,11 +21,7 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Lrn; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new LRN failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Lrn>(); | |||
| int64_t size = 0; | |||
| float alpha = 0; | |||
| @@ -34,12 +30,12 @@ ops::PrimitiveC *OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph, const | |||
| if (attribute_name == "alpha") { | |||
| alpha = onnx_node_attr.f(); | |||
| } else if (attribute_name == "beta") { | |||
| primitive_c->set_beta(onnx_node_attr.f()); | |||
| prim->set_beta(onnx_node_attr.f()); | |||
| } else if (attribute_name == "bias") { | |||
| primitive_c->set_bias(onnx_node_attr.f()); | |||
| prim->set_bias(onnx_node_attr.f()); | |||
| } else if (attribute_name == "size") { | |||
| size = onnx_node_attr.i(); | |||
| primitive_c->set_depth_radius(size / 2); | |||
| prim->set_depth_radius(size / 2); | |||
| } | |||
| } | |||
| @@ -48,9 +44,9 @@ ops::PrimitiveC *OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph, const | |||
| return nullptr; | |||
| } | |||
| alpha /= size; | |||
| primitive_c->set_alpha(alpha); | |||
| prim->set_alpha(alpha); | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxLrnxParser("Lrn", new OnnxLrnParser()); | |||
| @@ -21,32 +21,28 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxLstmParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::LSTM; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new LSTM failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::LSTM>(); | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| if (onnx_node_attr.name() == "direction") { | |||
| const auto &direction = onnx_node_attr.s(); | |||
| bool bidirectional = direction == "bidirectional"; | |||
| primitive_c->set_bidirectional(bidirectional); | |||
| prim->set_bidirectional(bidirectional); | |||
| if (bidirectional) { | |||
| primitive_c->set_num_directions(2); | |||
| prim->set_num_directions(2); | |||
| } else { | |||
| primitive_c->set_num_directions(1); | |||
| prim->set_num_directions(1); | |||
| } | |||
| } else if (onnx_node_attr.name() == "hidden_size") { | |||
| primitive_c->set_hidden_size(onnx_node_attr.i()); | |||
| prim->set_hidden_size(onnx_node_attr.i()); | |||
| } else if (onnx_node_attr.name() == "clip") { | |||
| primitive_c->set_dropout(onnx_node_attr.f()); | |||
| prim->set_dropout(onnx_node_attr.f()); | |||
| } else if (onnx_node_attr.name() == "activations") { | |||
| primitive_c->set_has_bias(true); | |||
| prim->set_has_bias(true); | |||
| } | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxLstmParser("LSTM", new OnnxLstmParser()); | |||
| @@ -21,20 +21,16 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::MatMul; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new MatMul failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::MatMul>(); | |||
| float alpha = 1.0f; | |||
| float beta = 1.0f; | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "transA") { | |||
| primitive_c->set_transpose_a(static_cast<bool>(onnx_node_attr.i())); | |||
| prim->set_transpose_a(static_cast<bool>(onnx_node_attr.i())); | |||
| } else if (attribute_name == "transB") { | |||
| primitive_c->set_transpose_b(static_cast<bool>(onnx_node_attr.i())); | |||
| prim->set_transpose_b(static_cast<bool>(onnx_node_attr.i())); | |||
| } else if (attribute_name == "alpha") { | |||
| alpha = onnx_node_attr.f(); | |||
| } else if (attribute_name == "beta") { | |||
| @@ -46,7 +42,7 @@ ops::PrimitiveC *OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, con | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxMatmulParser("MatMul", new OnnxMatmulParser()); | |||
| @@ -48,10 +48,7 @@ FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::st | |||
| const QuantType &quant_type) { | |||
| NoSupportOp::GetInstance()->SetFmkType("ONNX"); | |||
| func_graph_ptr_ = std::make_shared<FuncGraph>(); | |||
| if (func_graph_ptr_ == nullptr) { | |||
| MS_LOG(ERROR) << "funcgraph is nullptr."; | |||
| return nullptr; | |||
| } | |||
| auto status = InitOriginModel(model_file); | |||
| if (RET_OK != status) { | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| @@ -164,25 +161,24 @@ STATUS OnnxModelParser::ConvertNodes() { | |||
| if (status != RET_OK) { | |||
| continue; | |||
| } | |||
| auto primitive_c = node_parser->Parse(onnx_graph_, onnx_node); | |||
| auto prim = node_parser->Parse(onnx_graph_, onnx_node); | |||
| MS_LOG(INFO) << "parse op:" << onnx_node.op_type(); | |||
| if (primitive_c == nullptr) { | |||
| if (prim == nullptr) { | |||
| MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed."; | |||
| status = RET_ERROR; | |||
| continue; | |||
| } | |||
| status = ConvertOpQuantParams(onnx_node, primitive_c); | |||
| if (status != RET_OK) { | |||
| if (ConvertOpQuantParams(onnx_node, prim) != RET_OK) { | |||
| MS_LOG(ERROR) << "convert " << onnx_node.op_type() << " quant param failed."; | |||
| continue; | |||
| } | |||
| if (IsSpecialOnnxNode(onnx_node)) { | |||
| auto status_node = ConvertSpecialOnnxNode(onnx_node, primitive_c); | |||
| auto status_node = ConvertSpecialOnnxNode(onnx_node, prim); | |||
| status = status == RET_OK ? status_node : status; | |||
| continue; | |||
| } | |||
| // build CNode | |||
| status = BuildCNode(onnx_node, primitive_c); | |||
| status = BuildCNode(onnx_node, prim); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "build cnode " << onnx_node.op_type() << " failed."; | |||
| } | |||
| @@ -195,10 +191,7 @@ STATUS OnnxModelParser::ConvertGraphOutputs() { | |||
| if (onnx_graph_.output_size() > 1) { | |||
| std::vector<AnfNodePtr> make_tuple_inputs; | |||
| auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>(); | |||
| if (make_tuple_prim_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "new return nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| for (const auto &graph_out : onnx_graph_.output()) { | |||
| if (nodes_.find(graph_out.name()) == nodes_.end()) { | |||
| MS_LOG(ERROR) << "graph output get failed."; | |||
| @@ -236,19 +229,15 @@ STATUS OnnxModelParser::ConvertGraphOutputs() { | |||
| STATUS OnnxModelParser::BuildReturnNode(const std::vector<AnfNodePtr> &return_inputs) { | |||
| auto returnPrim = std::make_shared<ops::Return>(); | |||
| if (returnPrim == nullptr) { | |||
| MS_LOG(ERROR) << "new return nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto returnCnode = func_graph_ptr_->NewCNode(returnPrim, return_inputs); | |||
| returnCnode->set_fullname_with_scope("return"); | |||
| func_graph_ptr_->set_return(returnCnode); | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxModelParser::BuildCNode(const onnx::NodeProto &onnx_node, ops::PrimitiveC *primitive_c) { | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "primitive_c is nullptr."; | |||
| STATUS OnnxModelParser::BuildCNode(const onnx::NodeProto &onnx_node, ops::PrimitiveC *prim) { | |||
| if (prim == nullptr) { | |||
| MS_LOG(ERROR) << "prim is nullptr."; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::vector<AnfNodePtr> op_inputs; | |||
| @@ -263,7 +252,7 @@ STATUS OnnxModelParser::BuildCNode(const onnx::NodeProto &onnx_node, ops::Primit | |||
| op_inputs.push_back(nodes_[input_name]); | |||
| } | |||
| } | |||
| auto new_cnode = func_graph_ptr_->NewCNode(std::shared_ptr<ops::PrimitiveC>(primitive_c), op_inputs); | |||
| auto new_cnode = func_graph_ptr_->NewCNode(std::shared_ptr<ops::PrimitiveC>(prim), op_inputs); | |||
| new_cnode->set_fullname_with_scope(onnx_node.op_type() + "_" + onnx_node.output(0)); | |||
| auto status = BuildOpOutputs(onnx_node, new_cnode); | |||
| return status; | |||
| @@ -287,10 +276,6 @@ STATUS OnnxModelParser::BuildOpOutputs(const onnx::NodeProto &onnx_node, const C | |||
| auto type_ptr = TypeIdToType(kTypeUnknown); | |||
| abstract_list.emplace_back(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector)); | |||
| auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>(); | |||
| if (tuple_get_item_prim_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "new return nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr); | |||
| auto get_item_value = NewValueNode(MakeValue<int>(op_idx)); | |||
| std::vector<AnfNodePtr> inputs{tuple_get_item_prim, cnode, get_item_value}; | |||
| @@ -304,9 +289,9 @@ STATUS OnnxModelParser::BuildOpOutputs(const onnx::NodeProto &onnx_node, const C | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxModelParser::ConvertOpQuantParams(const onnx::NodeProto &onnx_node, ops::PrimitiveC *primitive_c) { | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "primitive_c is null, get quant params failed."; | |||
| STATUS OnnxModelParser::ConvertOpQuantParams(const onnx::NodeProto &onnx_node, ops::PrimitiveC *prim) { | |||
| if (prim == nullptr) { | |||
| MS_LOG(ERROR) << "prim is null, get quant params failed."; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto status = ParseQuantParam(onnx_node); | |||
| @@ -337,7 +322,7 @@ STATUS OnnxModelParser::ConvertOpQuantParams(const onnx::NodeProto &onnx_node, o | |||
| } | |||
| quant_params_holder->AddOutputQuantParam(quant_params); | |||
| } | |||
| primitive_c->AddAttr("quant_params", quant_params_holder); | |||
| prim->AddAttr("quant_params", quant_params_holder); | |||
| return RET_OK; | |||
| } | |||
| @@ -462,8 +447,8 @@ STATUS OnnxModelParser::CopyTensorQuantParam(const std::string &tensor_name, Qua | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxModelParser::ConvertSpecialOnnxNode(const onnx::NodeProto &onnx_node, ops::PrimitiveC *primitive_c) { | |||
| if (primitive_c == nullptr) { | |||
| STATUS OnnxModelParser::ConvertSpecialOnnxNode(const onnx::NodeProto &onnx_node, ops::PrimitiveC *prim) { | |||
| if (prim == nullptr) { | |||
| MS_LOG(ERROR) << "imitive_c is nullptr."; | |||
| return RET_NULL_PTR; | |||
| } | |||
| @@ -472,30 +457,30 @@ STATUS OnnxModelParser::ConvertSpecialOnnxNode(const onnx::NodeProto &onnx_node, | |||
| MS_LOG(ERROR) << "loop hasn't supported."; | |||
| return RET_NOT_FIND_OP; | |||
| } else if (onnx_node.op_type() == "Gemm") { | |||
| status = ConvertOnnxGemmNode(onnx_node, primitive_c); | |||
| status = ConvertOnnxGemmNode(onnx_node, prim); | |||
| } else { | |||
| MS_LOG(ERROR) << "the node is not special node."; | |||
| status = RET_ERROR; | |||
| } | |||
| delete primitive_c; | |||
| delete prim; | |||
| return status; | |||
| } | |||
| STATUS OnnxModelParser::ConvertOnnxGemmNode(const onnx::NodeProto &onnx_node, ops::PrimitiveC *primitive_c) { | |||
| STATUS OnnxModelParser::ConvertOnnxGemmNode(const onnx::NodeProto &onnx_node, ops::PrimitiveC *prim) { | |||
| if (onnx_node.op_type() != "Gemm") { | |||
| MS_LOG(ERROR) << "this op is not gemm, it is " << onnx_node.op_type(); | |||
| return RET_ERROR; | |||
| } | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "primitive_c is nullptr."; | |||
| if (prim == nullptr) { | |||
| MS_LOG(ERROR) << "prim is nullptr."; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto status = BuildCNodeForGemm(onnx_node, primitive_c, "MatMul"); | |||
| auto status = BuildCNodeForGemm(onnx_node, prim, "MatMul"); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "convert gemm node failed."; | |||
| return status; | |||
| } | |||
| status = BuildCNodeForGemm(onnx_node, primitive_c, "BiasAdd"); | |||
| status = BuildCNodeForGemm(onnx_node, prim, "BiasAdd"); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "convert gemm node failed."; | |||
| return status; | |||
| @@ -503,14 +488,14 @@ STATUS OnnxModelParser::ConvertOnnxGemmNode(const onnx::NodeProto &onnx_node, op | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxModelParser::BuildCNodeForGemm(const onnx::NodeProto &onnx_node, ops::PrimitiveC *primitive_c, | |||
| STATUS OnnxModelParser::BuildCNodeForGemm(const onnx::NodeProto &onnx_node, ops::PrimitiveC *prim, | |||
| const std::string &name) { | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "primitive_c is nullptr."; | |||
| if (prim == nullptr) { | |||
| MS_LOG(ERROR) << "prim is nullptr."; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto value = primitive_c->GetAttr(name); | |||
| primitive_c->EraseAttr(name); | |||
| auto value = prim->GetAttr(name); | |||
| prim->EraseAttr(name); | |||
| if (value == nullptr) { | |||
| MS_LOG(ERROR) << "op parse failed."; | |||
| return RET_NULL_PTR; | |||
| @@ -524,7 +509,7 @@ STATUS OnnxModelParser::BuildCNodeForGemm(const onnx::NodeProto &onnx_node, ops: | |||
| std::vector<int64_t> shape_vector; | |||
| std::vector<AnfNodePtr> op_inputs; | |||
| auto quant_params_holder = std::make_shared<QuantParamHolder>(); | |||
| auto quant_params_holder_origin = primitive_c->GetAttr("quant_params")->cast<QuantParamHolderPtr>(); | |||
| auto quant_params_holder_origin = prim->GetAttr("quant_params")->cast<QuantParamHolderPtr>(); | |||
| if (name == "MatMul") { | |||
| for (int i = 0; i < 2; ++i) { | |||
| if (nodes_.find(onnx_node.input(i)) == nodes_.end()) { | |||
| @@ -55,12 +55,12 @@ class OnnxModelParser : public ModelParser { | |||
| STATUS BuildReturnNode(const std::vector<AnfNodePtr> &return_inputs); | |||
| STATUS BuildParameterNode(const ParameterPtr ¶meter_node, const onnx::TensorProto &tensor); | |||
| STATUS BuildParameterNodeForQuantParam(void *data, const std::string &name, TypeId type); | |||
| STATUS BuildCNode(const onnx::NodeProto &onnx_node, ops::PrimitiveC *primitive_c); | |||
| STATUS BuildCNode(const onnx::NodeProto &onnx_node, ops::PrimitiveC *prim); | |||
| STATUS BuildOpOutputs(const onnx::NodeProto &onnx_node, const CNodePtr &cnode); | |||
| STATUS ConvertSpecialOnnxNode(const onnx::NodeProto &onnx_node, ops::PrimitiveC *primitive_c); | |||
| STATUS ConvertOnnxGemmNode(const onnx::NodeProto &onnx_node, ops::PrimitiveC *primitive_c); | |||
| STATUS BuildCNodeForGemm(const onnx::NodeProto &onnx_node, ops::PrimitiveC *primitive_c, const std::string &name); | |||
| STATUS ConvertOpQuantParams(const onnx::NodeProto &onnx_node, ops::PrimitiveC *primitive_c); | |||
| STATUS ConvertSpecialOnnxNode(const onnx::NodeProto &onnx_node, ops::PrimitiveC *prim); | |||
| STATUS ConvertOnnxGemmNode(const onnx::NodeProto &onnx_node, ops::PrimitiveC *prim); | |||
| STATUS BuildCNodeForGemm(const onnx::NodeProto &onnx_node, ops::PrimitiveC *prim, const std::string &name); | |||
| STATUS ConvertOpQuantParams(const onnx::NodeProto &onnx_node, ops::PrimitiveC *prim); | |||
| STATUS ParseQuantParam(const onnx::NodeProto &onnx_node); | |||
| STATUS SetTensorQuantParam(const std::string &tensor_name, std::vector<QuantParamT> *quant_params); | |||
| STATUS SetTensorQuantParamFromNode(const std::string &tensor_name, std::vector<QuantParamT> *quant_params); | |||
| @@ -22,22 +22,18 @@ namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxNonMaxSuppressionParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::NonMaxSuppression; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new NonMaxSuppression failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::NonMaxSuppression>(); | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "center_point_box") { | |||
| if (onnx_node_attr.has_i()) { | |||
| primitive_c->set_center_point_box(onnx_node_attr.i()); | |||
| prim->set_center_point_box(onnx_node_attr.i()); | |||
| } | |||
| } | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxNonMaxSuppressionParser("NonMaxSuppression", new OnnxNonMaxSuppressionParser()); | |||
| @@ -21,20 +21,16 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxOneHotParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::OneHot; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new OneHot failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::OneHot>(); | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "axis") { | |||
| primitive_c->set_axis(onnx_node_attr.i()); | |||
| prim->set_axis(onnx_node_attr.i()); | |||
| } | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxOneHotParser("OneHot", new OnnxOneHotParser()); | |||
| @@ -22,13 +22,9 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxPadParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::PadFusion; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new PadFusion failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::PadFusion>(); | |||
| mindspore::PaddingMode paddingMode; | |||
| mindspore::PaddingMode padding_mode; | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "pads") { | |||
| @@ -40,28 +36,28 @@ ops::PrimitiveC *OnnxPadParser::Parse(const onnx::GraphProto &onnx_graph, const | |||
| paddings[i][0] = static_cast<int64_t>(onnx_node_attr.ints(i)); | |||
| paddings[i][1] = static_cast<int64_t>(onnx_node_attr.ints(i + size / 2)); | |||
| } | |||
| primitive_c->set_paddings(paddings); | |||
| prim->set_paddings(paddings); | |||
| std::vector<std::vector<int32_t>> pads(size / 2, std::vector<int32_t>(2, 0)); | |||
| for (int i = 0; i < size / 2; i++) { | |||
| pads[i][0] = static_cast<int32_t>(onnx_node_attr.ints(i)); | |||
| pads[i][1] = static_cast<int32_t>(onnx_node_attr.ints(i + size / 2)); | |||
| } | |||
| primitive_c->AddAttr("pads", MakeValue(pads)); | |||
| prim->AddAttr("pads", MakeValue(pads)); | |||
| } else if (attribute_name == "mode") { | |||
| const auto &mode = onnx_node_attr.s(); | |||
| if (mode == "constant") { | |||
| paddingMode = mindspore::PaddingMode::CONSTANT; | |||
| padding_mode = mindspore::PaddingMode::CONSTANT; | |||
| } else if (mode == "reflect") { | |||
| paddingMode = mindspore::PaddingMode::REFLECT; | |||
| padding_mode = mindspore::PaddingMode::REFLECT; | |||
| } else if (mode == "edge") { | |||
| paddingMode = mindspore::PaddingMode::SYMMETRIC; | |||
| padding_mode = mindspore::PaddingMode::SYMMETRIC; | |||
| } | |||
| primitive_c->set_padding_mode(paddingMode); | |||
| prim->set_padding_mode(padding_mode); | |||
| } | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxPadParser("Pad", new OnnxPadParser()); | |||
| @@ -23,14 +23,10 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxAvgPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::AvgPoolFusion; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new AvgPoolFusion failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::AvgPoolFusion>(); | |||
| primitive_c->set_format(mindspore::Format::NCHW); | |||
| primitive_c->set_pad_mode(mindspore::PadMode::PAD); | |||
| prim->set_format(mindspore::Format::NCHW); | |||
| prim->set_pad_mode(mindspore::PadMode::PAD); | |||
| mindspore::RoundMode roundMode = mindspore::RoundMode::FLOOR; | |||
| std::vector<int64_t> kernels; | |||
| std::vector<int64_t> strides; | |||
| @@ -41,7 +37,7 @@ ops::PrimitiveC *OnnxAvgPoolParser::Parse(const onnx::GraphProto &onnx_graph, co | |||
| if (onnx_node_attr.ints_size() == 2) { | |||
| kernels.push_back(onnx_node_attr.ints(0)); | |||
| kernels.push_back(onnx_node_attr.ints(1)); | |||
| primitive_c->set_kernel_size(kernels); | |||
| prim->set_kernel_size(kernels); | |||
| } | |||
| } | |||
| if (attribute_name == "strides") { | |||
| @@ -52,7 +48,7 @@ ops::PrimitiveC *OnnxAvgPoolParser::Parse(const onnx::GraphProto &onnx_graph, co | |||
| } | |||
| if (attribute_name == "auto_pad") { | |||
| if (onnx_node_attr.s() == "SAME_UPPER") { | |||
| primitive_c->set_pad_mode(mindspore::PadMode::SAME); | |||
| prim->set_pad_mode(mindspore::PadMode::SAME); | |||
| } else if (onnx_node_attr.s() == "SAME_LOWER") { | |||
| MS_LOG(ERROR) << "PadMode_SAME_LOWER is not supported now"; | |||
| return nullptr; | |||
| @@ -78,34 +74,30 @@ ops::PrimitiveC *OnnxAvgPoolParser::Parse(const onnx::GraphProto &onnx_graph, co | |||
| return nullptr; | |||
| } | |||
| } | |||
| primitive_c->set_round_mode(roundMode); | |||
| prim->set_round_mode(roundMode); | |||
| if (strides.empty()) { | |||
| strides.push_back(1); | |||
| strides.push_back(1); | |||
| } | |||
| primitive_c->set_strides(strides); | |||
| prim->set_strides(strides); | |||
| if (pads.empty()) { | |||
| pads = {0, 0, 0, 0}; | |||
| } | |||
| primitive_c->set_pad(pads); | |||
| prim->set_pad(pads); | |||
| if (onnx_node.op_type() == "GlobalAveragePool") { | |||
| primitive_c->set_global(true); | |||
| prim->set_global(true); | |||
| } else { | |||
| primitive_c->set_global(false); | |||
| prim->set_global(false); | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| ops::PrimitiveC *OnnxMaxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::MaxPoolFusion; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new MaxPoolFusion failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::MaxPoolFusion>(); | |||
| primitive_c->set_format(mindspore::Format::NCHW); | |||
| prim->set_format(mindspore::Format::NCHW); | |||
| mindspore::RoundMode roundMode = mindspore::RoundMode::FLOOR; | |||
| std::vector<int64_t> kernels; | |||
| std::vector<int64_t> strides; | |||
| @@ -116,7 +108,7 @@ ops::PrimitiveC *OnnxMaxPoolParser::Parse(const onnx::GraphProto &onnx_graph, co | |||
| if (onnx_node_attr.ints_size() == 2) { | |||
| kernels.push_back(onnx_node_attr.ints(0)); | |||
| kernels.push_back(onnx_node_attr.ints(1)); | |||
| primitive_c->set_kernel_size(kernels); | |||
| prim->set_kernel_size(kernels); | |||
| } | |||
| } | |||
| if (attribute_name == "strides") { | |||
| @@ -127,7 +119,7 @@ ops::PrimitiveC *OnnxMaxPoolParser::Parse(const onnx::GraphProto &onnx_graph, co | |||
| } | |||
| if (attribute_name == "auto_pad") { | |||
| if (onnx_node_attr.s() == "SAME_UPPER") { | |||
| primitive_c->set_pad_mode(mindspore::PadMode::SAME); | |||
| prim->set_pad_mode(mindspore::PadMode::SAME); | |||
| } else if (onnx_node_attr.s() == "SAME_LOWER") { | |||
| MS_LOG(ERROR) << "PadMode_SAME_LOWER is not supported now"; | |||
| return nullptr; | |||
| @@ -135,7 +127,7 @@ ops::PrimitiveC *OnnxMaxPoolParser::Parse(const onnx::GraphProto &onnx_graph, co | |||
| } | |||
| if (attribute_name == "pads") { | |||
| if (onnx_node_attr.ints_size() == 4) { | |||
| primitive_c->set_pad_mode(mindspore::PadMode::PAD); | |||
| prim->set_pad_mode(mindspore::PadMode::PAD); | |||
| pads.push_back(onnx_node_attr.ints(0)); | |||
| pads.push_back(onnx_node_attr.ints(2)); | |||
| pads.push_back(onnx_node_attr.ints(1)); | |||
| @@ -154,22 +146,22 @@ ops::PrimitiveC *OnnxMaxPoolParser::Parse(const onnx::GraphProto &onnx_graph, co | |||
| return nullptr; | |||
| } | |||
| } | |||
| primitive_c->set_round_mode(roundMode); | |||
| prim->set_round_mode(roundMode); | |||
| if (pads.empty()) { | |||
| pads = {0, 0, 0, 0}; | |||
| } | |||
| primitive_c->set_pad(pads); | |||
| prim->set_pad(pads); | |||
| if (strides.empty()) { | |||
| strides.push_back(1); | |||
| strides.push_back(1); | |||
| } | |||
| primitive_c->set_strides(strides); | |||
| prim->set_strides(strides); | |||
| primitive_c->set_global(onnx_node.op_type() == "GlobalMaxPool"); | |||
| prim->set_global(onnx_node.op_type() == "GlobalMaxPool"); | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxAveragePoolParser("AveragePool", new OnnxAvgPoolParser()); | |||
| @@ -21,24 +21,20 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxQuantizeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::QuantDTypeCast; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new QuantDTypeCast failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::QuantDTypeCast>(); | |||
| if (onnx_node.op_type() == "Int8Quantize") { | |||
| primitive_c->set_src_t(kNumberTypeFloat32); | |||
| primitive_c->set_dst_t(kNumberTypeUInt8); | |||
| prim->set_src_t(kNumberTypeFloat32); | |||
| prim->set_dst_t(kNumberTypeUInt8); | |||
| } else if (onnx_node.op_type() == "Int8Dequantize") { | |||
| primitive_c->set_src_t(kNumberTypeUInt8); | |||
| primitive_c->set_dst_t(kNumberTypeFloat32); | |||
| prim->set_src_t(kNumberTypeUInt8); | |||
| prim->set_dst_t(kNumberTypeFloat32); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported nodeType: " << onnx_node.op_type().c_str(); | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxInt8QuantizeParser("Int8Quantize", new OnnxQuantizeParser()); | |||
| @@ -21,15 +21,11 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxRangeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Range; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Range failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Range>(); | |||
| primitive_c->set_d_type(0); | |||
| prim->set_d_type(0); | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxRangeParser("Range", new OnnxRangeParser()); | |||
| @@ -22,13 +22,9 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxReduceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::ReduceFusion; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new ReduceFusion failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::ReduceFusion>(); | |||
| primitive_c->set_keep_dims(true); | |||
| prim->set_keep_dims(true); | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "axes") { | |||
| @@ -37,30 +33,30 @@ ops::PrimitiveC *OnnxReduceParser::Parse(const onnx::GraphProto &onnx_graph, con | |||
| for (int i = 0; i < size; ++i) { | |||
| axes.push_back(onnx_node_attr.ints(i)); | |||
| } | |||
| primitive_c->AddAttr("axes", MakeValue(axes)); | |||
| prim->AddAttr("axes", MakeValue(axes)); | |||
| } else if (attribute_name == "keepdims") { | |||
| primitive_c->set_keep_dims(static_cast<bool>(onnx_node_attr.i())); | |||
| prim->set_keep_dims(static_cast<bool>(onnx_node_attr.i())); | |||
| } | |||
| } | |||
| const auto &type = onnx_node.op_type(); | |||
| if (type == "ReduceMean") { | |||
| primitive_c->set_mode(mindspore::ReduceMode::Reduce_Mean); | |||
| prim->set_mode(mindspore::ReduceMode::Reduce_Mean); | |||
| } else if (type == "ReduceMax") { | |||
| primitive_c->set_mode(mindspore::ReduceMode::Reduce_Max); | |||
| prim->set_mode(mindspore::ReduceMode::Reduce_Max); | |||
| } else if (type == "ReduceMin") { | |||
| primitive_c->set_mode(mindspore::ReduceMode::Reduce_Min); | |||
| prim->set_mode(mindspore::ReduceMode::Reduce_Min); | |||
| } else if (type == "ReduceSum") { | |||
| primitive_c->set_mode(mindspore::ReduceMode::Reduce_Sum); | |||
| prim->set_mode(mindspore::ReduceMode::Reduce_Sum); | |||
| } else if (type == "ReduceProd") { | |||
| primitive_c->set_mode(mindspore::ReduceMode::Reduce_Prod); | |||
| prim->set_mode(mindspore::ReduceMode::Reduce_Prod); | |||
| } else if (type == "ReduceSumSquare") { | |||
| primitive_c->set_mode(mindspore::ReduceMode::Reduce_Sum_Square); | |||
| prim->set_mode(mindspore::ReduceMode::Reduce_Sum_Square); | |||
| } else { | |||
| MS_LOG(ERROR) << "unsupported reduce type: " << type; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxReduceMeanParser("ReduceMean", new OnnxReduceParser()); | |||
| @@ -22,11 +22,7 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Reshape; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Reshape failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Reshape>(); | |||
| std::vector<int32_t> shape; | |||
| shape.clear(); | |||
| @@ -37,12 +33,12 @@ ops::PrimitiveC *OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph, co | |||
| for (int i = 0; i < onnx_node_attr.ints_size(); ++i) { | |||
| shape.push_back(static_cast<int>(onnx_node_attr.ints(i))); | |||
| } | |||
| primitive_c->AddAttr("shape", MakeValue(shape)); | |||
| prim->AddAttr("shape", MakeValue(shape)); | |||
| } | |||
| } | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxReshapeParser("Reshape", new OnnxReshapeParser()); | |||
| @@ -34,14 +34,11 @@ ops::PrimitiveC *OnnxResizeParser::Parse(const onnx::GraphProto &onnx_graph, con | |||
| } | |||
| // use bilinear method | |||
| auto primitive_c = new (std::nothrow) ops::Resize; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Resize failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Resize>(); | |||
| prim->set_format(mindspore::Format::NCHW); | |||
| prim->set_nearest_mode(mindspore::NearestMode::ROUND_HALF_DOWN); | |||
| primitive_c->set_format(mindspore::Format::NCHW); | |||
| primitive_c->set_nearest_mode(mindspore::NearestMode::ROUND_HALF_DOWN); | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "coordinate_transformation_mode") { | |||
| @@ -51,24 +48,24 @@ ops::PrimitiveC *OnnxResizeParser::Parse(const onnx::GraphProto &onnx_graph, con | |||
| {"align_corners", mindspore::CoordinateTransformMode::ALIGN_CORNERS}, | |||
| {"asymmetric", mindspore::CoordinateTransformMode::ASYMMETRIC}}; | |||
| if (transform_map.find(onnx_node_attr.s()) != transform_map.end()) { | |||
| primitive_c->set_coordinate_transform_mode(transform_map[onnx_node_attr.s()]); | |||
| prim->set_coordinate_transform_mode(transform_map[onnx_node_attr.s()]); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupport coordinate transform mode: " << attribute_name; | |||
| return nullptr; | |||
| } | |||
| } else if (attribute_name == "cubic_coeff_a") { | |||
| primitive_c->set_cubic_coeff(onnx_node_attr.f()); | |||
| prim->set_cubic_coeff(onnx_node_attr.f()); | |||
| } else if (attribute_name == "exclude_outside") { | |||
| primitive_c->set_exclude_outside(onnx_node_attr.i()); | |||
| prim->set_exclude_outside(onnx_node_attr.i()); | |||
| } else if (attribute_name == "extrapolation_value") { | |||
| primitive_c->set_extrapolation_value(onnx_node_attr.f()); | |||
| prim->set_extrapolation_value(onnx_node_attr.f()); | |||
| } else if (attribute_name == "mode") { | |||
| std::map<std::string, mindspore::ResizeMethod> resize_mode = { | |||
| {"nearest", mindspore::ResizeMethod::NEAREST}, | |||
| {"linear", mindspore::ResizeMethod::LINEAR}, | |||
| {"cubic", mindspore::ResizeMethod::CUBIC}, | |||
| }; | |||
| primitive_c->set_method(resize_mode[onnx_node_attr.s()]); | |||
| prim->set_method(resize_mode[onnx_node_attr.s()]); | |||
| } else if (attribute_name == "nearest_mode") { | |||
| std::map<std::string, mindspore::NearestMode> nearest_mode = { | |||
| {"round_prefer_floor", mindspore::NearestMode::ROUND_HALF_DOWN}, | |||
| @@ -76,11 +73,11 @@ ops::PrimitiveC *OnnxResizeParser::Parse(const onnx::GraphProto &onnx_graph, con | |||
| {"floor", mindspore::NearestMode::FLOOR}, | |||
| {"ceil", mindspore::NearestMode::CEIL}, | |||
| }; | |||
| primitive_c->set_nearest_mode(nearest_mode[onnx_node_attr.s()]); | |||
| prim->set_nearest_mode(nearest_mode[onnx_node_attr.s()]); | |||
| } | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxResizeParser("Resize", new OnnxResizeParser()); | |||
| @@ -21,13 +21,8 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxShapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Shape; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Shape failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::Shape>(); | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxShapeParser("Shape", new OnnxShapeParser()); | |||
| @@ -26,11 +26,7 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::StridedSlice; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new StridedSlice failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::StridedSlice>(); | |||
| std::vector<int32_t> starts; | |||
| std::vector<int32_t> ends; | |||
| @@ -76,7 +72,7 @@ ops::PrimitiveC *OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, cons | |||
| size = static_cast<int>(steps.size()); | |||
| } | |||
| if (size == -1) { | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| if (axes.empty()) { | |||
| for (size_t i = 0; i < starts.size(); ++i) { | |||
| @@ -87,11 +83,12 @@ ops::PrimitiveC *OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, cons | |||
| steps.assign(starts.size(), 1); | |||
| } | |||
| primitive_c->AddAttr("starts", MakeValue(starts)); | |||
| primitive_c->AddAttr("axes", MakeValue(axes)); | |||
| primitive_c->AddAttr("ends", MakeValue(ends)); | |||
| primitive_c->AddAttr("steps", MakeValue(steps)); | |||
| return primitive_c; | |||
| prim->AddAttr("starts", MakeValue(starts)); | |||
| prim->AddAttr("axes", MakeValue(axes)); | |||
| prim->AddAttr("ends", MakeValue(ends)); | |||
| prim->AddAttr("steps", MakeValue(steps)); | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxSliceParser("Slice", new OnnxSliceParser()); | |||
| @@ -21,11 +21,7 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxSoftMaxParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Softmax; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new SoftMax failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Softmax>(); | |||
| int64_t axis; | |||
| bool axis_is_def = true; | |||
| @@ -39,9 +35,9 @@ ops::PrimitiveC *OnnxSoftMaxParser::Parse(const onnx::GraphProto &onnx_graph, co | |||
| if (axis_is_def) { | |||
| axis = OnnxNodeParser::opset_version() >= 13 ? -1 : 1; | |||
| } | |||
| primitive_c->set_axis({axis}); | |||
| prim->set_axis({axis}); | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxSoftMaxParser("Softmax", new OnnxSoftMaxParser()); | |||
| @@ -21,20 +21,16 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxSpaceToDepthParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::SpaceToDepth; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new SpaceToDepth failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::SpaceToDepth>(); | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "blocksize") { | |||
| primitive_c->set_block_size(onnx_node_attr.i()); | |||
| prim->set_block_size(onnx_node_attr.i()); | |||
| } | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxSpaceToDepthParser("SpaceToDepth", new OnnxSpaceToDepthParser()); | |||
| @@ -23,31 +23,28 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxSplitParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Split; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Split failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Split>(); | |||
| primitive_c->set_axis(0); | |||
| prim->set_axis(0); | |||
| std::vector<int64_t> size_splits; | |||
| int64_t split_num = 0; | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "axis") { | |||
| primitive_c->set_axis(onnx_node_attr.i()); | |||
| prim->set_axis(onnx_node_attr.i()); | |||
| } else if (attribute_name == "split") { | |||
| size_splits.resize(onnx_node_attr.ints_size()); | |||
| std::copy(onnx_node_attr.ints().begin(), onnx_node_attr.ints().end(), size_splits.begin()); | |||
| primitive_c->set_size_splits(size_splits); | |||
| prim->set_size_splits(size_splits); | |||
| split_num = onnx_node_attr.ints_size(); | |||
| } | |||
| } | |||
| if (split_num == 0) { | |||
| split_num = onnx_node.output_size(); | |||
| } | |||
| primitive_c->set_output_num(split_num); | |||
| return primitive_c; | |||
| prim->set_output_num(split_num); | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxSplitParser("Split", new OnnxSplitParser()); | |||
| @@ -22,11 +22,7 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxSqueezeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Squeeze; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Squeeze failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Squeeze>(); | |||
| std::vector<int64_t> axis; | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| @@ -35,11 +31,11 @@ ops::PrimitiveC *OnnxSqueezeParser::Parse(const onnx::GraphProto &onnx_graph, co | |||
| for (int i = 0; i < onnx_node_attr.ints().size(); ++i) { | |||
| axis.emplace_back(onnx_node_attr.ints(i)); | |||
| } | |||
| primitive_c->set_axis(axis); | |||
| prim->set_axis(axis); | |||
| } | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxSqueezeParser("Squeeze", new OnnxSqueezeParser()); | |||
| @@ -21,13 +21,8 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxTileParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::TileFusion; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new TileFusion failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::TileFusion>(); | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxTileParser("Tile", new OnnxTileParser()); | |||
| @@ -21,20 +21,16 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxTopkParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::TopKFusion; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new TopKFusion failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::TopKFusion>(); | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "k") { | |||
| primitive_c->AddAttr("k", MakeValue(static_cast<int32_t>(onnx_node_attr.i()))); | |||
| prim->AddAttr("k", MakeValue(static_cast<int32_t>(onnx_node_attr.i()))); | |||
| } | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxTopkParser("TopK", new OnnxTopkParser()); | |||
| @@ -22,11 +22,7 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxTransposeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Transpose; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Transpose failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Transpose>(); | |||
| std::vector<int32_t> perm; | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| @@ -36,11 +32,11 @@ ops::PrimitiveC *OnnxTransposeParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| for (int i = 0; i < onnx_node_attr.ints_size(); ++i) { | |||
| perm[i] = onnx_node_attr.ints(i); | |||
| } | |||
| primitive_c->AddAttr("perm", MakeValue(perm)); | |||
| prim->AddAttr("perm", MakeValue(perm)); | |||
| } | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxTransposeParser("Transpose", new OnnxTransposeParser()); | |||
| @@ -22,11 +22,7 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxUnSqueezeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| auto primitive_c = new (std::nothrow) ops::Unsqueeze; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Unsqueeze failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Unsqueeze>(); | |||
| std::vector<int64_t> axis; | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| @@ -35,11 +31,11 @@ ops::PrimitiveC *OnnxUnSqueezeParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| for (int i = 0; i < onnx_node_attr.ints().size(); ++i) { | |||
| axis.emplace_back(onnx_node_attr.ints(i)); | |||
| } | |||
| primitive_c->set_axis(axis); | |||
| prim->set_axis(axis); | |||
| } | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxUnsqueezeParser("Unsqueeze", new OnnxUnSqueezeParser()); | |||
| @@ -23,14 +23,10 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *OnnxUpsampleParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||
| // use bilinear method | |||
| auto primitive_c = new (std::nothrow) ops::Resize; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Resize failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Resize>(); | |||
| prim->set_method(mindspore::ResizeMethod::NEAREST); // use bilinear method | |||
| primitive_c->set_method(mindspore::ResizeMethod::NEAREST); | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "mode") { | |||
| @@ -38,13 +34,13 @@ ops::PrimitiveC *OnnxUpsampleParser::Parse(const onnx::GraphProto &onnx_graph, c | |||
| MS_LOG(ERROR) << "the UpSample mode don't support now."; | |||
| return nullptr; | |||
| } | |||
| primitive_c->set_method(onnx_node_attr.s() == "nearest" ? mindspore::ResizeMethod::NEAREST | |||
| : mindspore::ResizeMethod::LINEAR); | |||
| prim->set_method(onnx_node_attr.s() == "nearest" ? mindspore::ResizeMethod::NEAREST | |||
| : mindspore::ResizeMethod::LINEAR); | |||
| } | |||
| } | |||
| primitive_c->set_coordinate_transform_mode(mindspore::CoordinateTransformMode::ASYMMETRIC); | |||
| prim->set_coordinate_transform_mode(mindspore::CoordinateTransformMode::ASYMMETRIC); | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| OnnxNodeRegistrar g_onnxUpsampleParser("Upsample", new OnnxUpsampleParser()); | |||
| @@ -25,20 +25,16 @@ namespace lite { | |||
| ops::PrimitiveC *TFActivationParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| auto primitive_c = new (std::nothrow) ops::Activation(); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Activation failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Activation>(); | |||
| if (tf_op.op() == "Relu") { | |||
| primitive_c->set_activation_type(mindspore::ActivationType::RELU); | |||
| prim->set_activation_type(mindspore::ActivationType::RELU); | |||
| } else if (tf_op.op() == "Relu6") { | |||
| primitive_c->set_activation_type(mindspore::ActivationType::RELU6); | |||
| prim->set_activation_type(mindspore::ActivationType::RELU6); | |||
| } else if (tf_op.op() == "Sigmoid") { | |||
| primitive_c->set_activation_type(mindspore::ActivationType::SIGMOID); | |||
| prim->set_activation_type(mindspore::ActivationType::SIGMOID); | |||
| } else if (tf_op.op() == "Tanh") { | |||
| primitive_c->set_activation_type(mindspore::ActivationType::TANH); | |||
| prim->set_activation_type(mindspore::ActivationType::TANH); | |||
| } else { | |||
| MS_LOG(ERROR) << "unsupported activation type:" << tf_op.op(); | |||
| return nullptr; | |||
| @@ -49,7 +45,8 @@ ops::PrimitiveC *TFActivationParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| MS_LOG(ERROR) << "add op input failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| TFNodeRegistrar g_tfReluParser("Relu", new TFActivationParser()); | |||
| @@ -44,89 +44,41 @@ ops::PrimitiveC *TFArithmeticParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| } | |||
| if (tf_op.op() == "Add" || tf_op.op() == "AddV2") { | |||
| auto primitive_c = new (std::nothrow) ops::AddFusion; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new AddFusion failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::AddFusion>(); | |||
| return prim.release(); | |||
| } else if (tf_op.op() == "Sub") { | |||
| auto primitive_c = new (std::nothrow) ops::SubFusion; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new SubFusion failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::SubFusion>(); | |||
| return prim.release(); | |||
| } else if (tf_op.op() == "Mul") { | |||
| auto primitive_c = new (std::nothrow) ops::MulFusion; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new MulFusion failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::MulFusion>(); | |||
| return prim.release(); | |||
| } else if (tf_op.op() == "Div" || tf_op.op() == "RealDiv") { | |||
| auto primitive_c = new (std::nothrow) ops::DivFusion; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new DivFusion failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::DivFusion>(); | |||
| return prim.release(); | |||
| } else if (tf_op.op() == "Maximum") { | |||
| auto primitive_c = new (std::nothrow) ops::Maximum; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Maximum failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::Maximum>(); | |||
| return prim.release(); | |||
| } else if (tf_op.op() == "Minimum") { | |||
| auto primitive_c = new (std::nothrow) ops::Minimum; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Minimum failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::Minimum>(); | |||
| return prim.release(); | |||
| } else if (tf_op.op() == "Greater") { | |||
| auto primitive_c = new (std::nothrow) ops::Greater; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Greater failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::Greater>(); | |||
| return prim.release(); | |||
| } else if (tf_op.op() == "GreaterEqual") { | |||
| auto primitive_c = new (std::nothrow) ops::GreaterEqual; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new GreaterEqual failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::GreaterEqual>(); | |||
| return prim.release(); | |||
| } else if (tf_op.op() == "Less") { | |||
| auto primitive_c = new (std::nothrow) ops::Less; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Less failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::Less>(); | |||
| return prim.release(); | |||
| } else if (tf_op.op() == "LessEqual") { | |||
| auto primitive_c = new (std::nothrow) ops::LessEqual; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new LessEqual failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::LessEqual>(); | |||
| return prim.release(); | |||
| } else if (tf_op.op() == "Equal") { | |||
| auto primitive_c = new (std::nothrow) ops::Equal; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Equal failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::Equal>(); | |||
| return prim.release(); | |||
| } else if (tf_op.op() == "NotEqual") { | |||
| auto primitive_c = new (std::nothrow) ops::NotEqual; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new NotEqual failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| auto prim = std::make_unique<ops::NotEqual>(); | |||
| return prim.release(); | |||
| } | |||
| return nullptr; | |||
| } | |||
| @@ -26,18 +26,14 @@ namespace lite { | |||
| ops::PrimitiveC *TFAssertParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| auto primitive_c = new (std::nothrow) ops::Assert; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "New Assert failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Assert>(); | |||
| tensorflow::AttrValue attr_value; | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "summarize", &attr_value)) { | |||
| MS_LOG(ERROR) << "The keep_dims attr should be specified"; | |||
| return nullptr; | |||
| } | |||
| primitive_c->set_summarize((int64_t)(attr_value.i())); | |||
| prim->set_summarize((int64_t)(attr_value.i())); | |||
| *output_size = 0; // Assert not have output | |||
| for (int i = 0; i < tf_op.input_size(); ++i) { | |||
| @@ -47,7 +43,7 @@ ops::PrimitiveC *TFAssertParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| } | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| TFNodeRegistrar g_tfAssertParser("Assert", new TFAssertParser()); | |||
| @@ -26,11 +26,7 @@ namespace lite { | |||
| ops::PrimitiveC *TFBiasAddParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| auto primitive_c = new (std::nothrow) ops::BiasAdd; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new BiasAdd failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::BiasAdd>(); | |||
| *output_size = 1; | |||
| if (AddOpInput(tf_op, 0, inputs) != RET_OK || AddOpInput(tf_op, 1, inputs) != RET_OK) { | |||
| @@ -38,7 +34,7 @@ ops::PrimitiveC *TFBiasAddParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| TFNodeRegistrar g_tfBiasAddParser("BiasAdd", new TFBiasAddParser()); | |||
| @@ -26,18 +26,14 @@ namespace lite { | |||
| ops::PrimitiveC *TFCastParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| auto primitive_c = new (std::nothrow) ops::Cast; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Cast failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Cast>(); | |||
| auto dst_type = TensorFlowUtils::ParseAttrDataType(tf_op, "DstT"); | |||
| if (dst_type == kTypeUnknown) { | |||
| MS_LOG(ERROR) << "Get attr DstT failed"; | |||
| return nullptr; | |||
| } | |||
| primitive_c->AddAttr("to", MakeValue(static_cast<int32_t>(dst_type))); | |||
| prim->AddAttr("to", MakeValue(static_cast<int32_t>(dst_type))); | |||
| *output_size = 1; | |||
| if (AddOpInput(tf_op, 0, inputs) != RET_OK) { | |||
| @@ -45,7 +41,7 @@ ops::PrimitiveC *TFCastParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| TFNodeRegistrar g_tfCastParser("Cast", new TFCastParser()); | |||
| @@ -26,11 +26,7 @@ namespace lite { | |||
| ops::PrimitiveC *TFConcatParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| auto primitive_c = new (std::nothrow) ops::Concat; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Concat failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Concat>(); | |||
| auto axis_node = GetConstInputNode(tf_node_map, tf_op.input(tf_op.input_size() - 1)); | |||
| if (axis_node == nullptr) { | |||
| @@ -43,7 +39,7 @@ ops::PrimitiveC *TFConcatParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| return nullptr; | |||
| } | |||
| auto tensor_proto = attr_value.tensor(); | |||
| primitive_c->set_axis(tensor_proto.int_val(0)); | |||
| prim->set_axis(tensor_proto.int_val(0)); | |||
| *output_size = 1; | |||
| for (int i = 0; i < tf_op.input_size() - 1; ++i) { | |||
| @@ -53,7 +49,7 @@ ops::PrimitiveC *TFConcatParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| } | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| TFNodeRegistrar g_tfConcatV2Parser("ConcatV2", new TFConcatParser()); | |||
| @@ -27,14 +27,10 @@ namespace lite { | |||
| ops::PrimitiveC *TFConvParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| auto primitive_c = new (std::nothrow) ops::Conv2DFusion; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Conv2DFusion failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Conv2DFusion>(); | |||
| primitive_c->set_pad({0, 0, 0, 0}); | |||
| primitive_c->set_group(1); | |||
| prim->set_pad({0, 0, 0, 0}); | |||
| prim->set_group(1); | |||
| // parse format | |||
| auto format = TensorFlowUtils::ParseNodeFormat(tf_op); | |||
| @@ -42,7 +38,7 @@ ops::PrimitiveC *TFConvParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| MS_LOG(ERROR) << "TF Conv2D with data_format=NCHW is not supported now"; | |||
| return nullptr; | |||
| } | |||
| primitive_c->set_format(format); | |||
| prim->set_format(format); | |||
| // parse kernel | |||
| auto weight_node = GetConstInputNode(tf_node_map, tf_op.input(1)); | |||
| @@ -55,9 +51,9 @@ ops::PrimitiveC *TFConvParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| MS_LOG(ERROR) << "parse kernels failed"; | |||
| return nullptr; | |||
| } | |||
| primitive_c->set_kernel_size({kernels[0], kernels[1]}); | |||
| primitive_c->set_out_channel(kernels[3]); | |||
| primitive_c->set_in_channel(kernels[2]); | |||
| prim->set_kernel_size({kernels[0], kernels[1]}); | |||
| prim->set_out_channel(kernels[3]); | |||
| prim->set_in_channel(kernels[2]); | |||
| // parse stride | |||
| std::vector<int64_t> strides(2); | |||
| @@ -65,7 +61,7 @@ ops::PrimitiveC *TFConvParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| MS_LOG(ERROR) << "parse strides failed"; | |||
| return nullptr; | |||
| } | |||
| primitive_c->set_stride(strides); | |||
| prim->set_stride(strides); | |||
| // parse dilation | |||
| std::vector<int64_t> dilations(2); | |||
| @@ -73,11 +69,11 @@ ops::PrimitiveC *TFConvParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| MS_LOG(ERROR) << "parse dilations failed"; | |||
| return nullptr; | |||
| } | |||
| primitive_c->set_dilation(dilations); | |||
| prim->set_dilation(dilations); | |||
| // parse pad | |||
| auto padMode = ParsePadMode(tf_op); | |||
| primitive_c->set_pad_mode(padMode); | |||
| auto pad_mode = ParsePadMode(tf_op); | |||
| prim->set_pad_mode(pad_mode); | |||
| *output_size = 1; | |||
| if (AddOpInput(tf_op, 0, inputs) != RET_OK || AddOpInput(tf_op, 1, inputs) != RET_OK) { | |||
| @@ -85,7 +81,7 @@ ops::PrimitiveC *TFConvParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| TFNodeRegistrar g_tfConvParser("Conv2D", new TFConvParser()); | |||
| @@ -26,11 +26,7 @@ namespace lite { | |||
| ops::PrimitiveC *TFExpandDimsParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| auto primitive_c = new (std::nothrow) ops::ExpandDims; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new ExpandDims failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::ExpandDims>(); | |||
| *output_size = 1; | |||
| if (AddOpInput(tf_op, 0, inputs) != RET_OK || AddOpInput(tf_op, 1, inputs) != RET_OK) { | |||
| @@ -38,7 +34,7 @@ ops::PrimitiveC *TFExpandDimsParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| TFNodeRegistrar g_tfExpandDimsParser("ExpandDims", new TFExpandDimsParser()); | |||
| @@ -26,11 +26,7 @@ namespace lite { | |||
| ops::PrimitiveC *TFGatherParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| auto primitive_c = new (std::nothrow) ops::Gather; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Gather failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Gather>(); | |||
| int batchDims = 0; | |||
| tensorflow::AttrValue attr_value; | |||
| @@ -72,7 +68,7 @@ ops::PrimitiveC *TFGatherParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| if (batchDims != 0 && !axis_is_set) { | |||
| axis = batchDims; | |||
| } | |||
| primitive_c->AddAttr("axis", MakeValue(axis)); | |||
| prim->AddAttr("axis", MakeValue(axis)); | |||
| *output_size = 1; | |||
| if (AddOpInput(tf_op, 0, inputs) != RET_OK || AddOpInput(tf_op, 1, inputs) != RET_OK) { | |||
| @@ -80,7 +76,7 @@ ops::PrimitiveC *TFGatherParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| TFNodeRegistrar g_tfGatherV2Parser("Gather", new TFGatherParser()); | |||
| @@ -27,16 +27,12 @@ ops::PrimitiveC *TFLogicalParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| if (tf_op.op() == "LogicalAnd") { | |||
| auto primitive_c = new (std::nothrow) ops::LogicalAnd; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new LogicalAnd failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::LogicalAnd>(); | |||
| *output_size = 1; | |||
| for (int i = 0; i < tf_op.input_size(); i++) { | |||
| inputs->emplace_back(tf_op.input(i)); | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } else { | |||
| MS_LOG(ERROR) << "only LogicalAnd is supported now"; | |||
| return nullptr; | |||
| @@ -26,18 +26,14 @@ namespace lite { | |||
| ops::PrimitiveC *TFMatMulParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| auto primitive_c = new (std::nothrow) ops::MatMul; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new MatMul failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::MatMul>(); | |||
| tensorflow::AttrValue attr_value; | |||
| if (TensorFlowUtils::FindAttrValue(tf_op, "transpose_a", &attr_value)) { | |||
| primitive_c->set_transpose_a(attr_value.b()); | |||
| prim->set_transpose_a(attr_value.b()); | |||
| } | |||
| if (TensorFlowUtils::FindAttrValue(tf_op, "transpose_b", &attr_value)) { | |||
| primitive_c->set_transpose_b(attr_value.b()); | |||
| prim->set_transpose_b(attr_value.b()); | |||
| } | |||
| *output_size = 1; | |||
| @@ -46,7 +42,7 @@ ops::PrimitiveC *TFMatMulParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| TFNodeRegistrar g_tfMatMulParser("MatMul", new TFMatMulParser()); | |||
| @@ -26,18 +26,14 @@ namespace lite { | |||
| ops::PrimitiveC *TFPackParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| auto primitive_c = new (std::nothrow) ops::Stack; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Stack failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Stack>(); | |||
| tensorflow::AttrValue attr_value; | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "axis", &attr_value)) { | |||
| MS_LOG(ERROR) << "The axis attr should be specified"; | |||
| return nullptr; | |||
| } | |||
| primitive_c->set_axis({attr_value.i()}); | |||
| prim->set_axis({attr_value.i()}); | |||
| *output_size = 1; | |||
| for (int i = 0; i < tf_op.input_size(); ++i) { | |||
| @@ -47,7 +43,7 @@ ops::PrimitiveC *TFPackParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| } | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| TFNodeRegistrar g_tfPackParser("Pack", new TFPackParser()); | |||
| @@ -32,30 +32,26 @@ ops::PrimitiveC *TFRaggedRangeParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| return nullptr; | |||
| } | |||
| auto primitive = new (std::nothrow) ops::Range; | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "New RaggedRange failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Range>(); | |||
| tensorflow::AttrValue attr_value; | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "starts", &attr_value)) { | |||
| MS_LOG(ERROR) << "The starts attr should be specified"; | |||
| return nullptr; | |||
| } | |||
| primitive->set_start(static_cast<int64_t>(attr_value.i())); | |||
| prim->set_start(static_cast<int64_t>(attr_value.i())); | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "limits", &attr_value)) { | |||
| MS_LOG(ERROR) << "The limits attr should be specified"; | |||
| return nullptr; | |||
| } | |||
| primitive->set_limit(static_cast<int64_t>(attr_value.i())); | |||
| prim->set_limit(static_cast<int64_t>(attr_value.i())); | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "deltas", &attr_value)) { | |||
| MS_LOG(ERROR) << "The deltas attr should be specified"; | |||
| return nullptr; | |||
| } | |||
| primitive->set_delta(static_cast<int64_t>(attr_value.i())); | |||
| prim->set_delta(static_cast<int64_t>(attr_value.i())); | |||
| *output_size = 1; | |||
| auto status = AddOpInput(tf_op, 0, inputs); | |||
| @@ -63,7 +59,7 @@ ops::PrimitiveC *TFRaggedRangeParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| MS_LOG(ERROR) << "add op input is failed!"; | |||
| return nullptr; | |||
| } | |||
| return primitive; | |||
| return prim.release(); | |||
| } | |||
| TFNodeRegistrar g_tfRaggedRangeParser("RaggedRange", new TFRaggedRangeParser()); | |||
| @@ -33,23 +33,19 @@ ops::PrimitiveC *TFRangeParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| return nullptr; | |||
| } | |||
| auto primitive_c = new (std::nothrow) ops::Range; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "New Range failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Range>(); | |||
| tensorflow::AttrValue attr_value; | |||
| if (TensorFlowUtils::FindAttrValue(tf_op, "start", &attr_value)) { | |||
| primitive_c->set_start(static_cast<int64_t>(attr_value.i())); | |||
| prim->set_start(static_cast<int64_t>(attr_value.i())); | |||
| } | |||
| if (TensorFlowUtils::FindAttrValue(tf_op, "limit", &attr_value)) { | |||
| primitive_c->set_limit(static_cast<int64_t>(attr_value.i())); | |||
| prim->set_limit(static_cast<int64_t>(attr_value.i())); | |||
| } | |||
| if (TensorFlowUtils::FindAttrValue(tf_op, "delta", &attr_value)) { | |||
| primitive_c->set_delta(static_cast<int64_t>(attr_value.i())); | |||
| prim->set_delta(static_cast<int64_t>(attr_value.i())); | |||
| } | |||
| *output_size = 1; | |||
| @@ -60,7 +56,8 @@ ops::PrimitiveC *TFRangeParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| MS_LOG(ERROR) << "add op input failed!"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| TFNodeRegistrar g_tfRangeParser("Range", new TFRangeParser()); | |||
| @@ -26,24 +26,20 @@ namespace lite { | |||
| ops::PrimitiveC *TFReduceParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| auto primitive_c = new (std::nothrow) ops::ReduceFusion; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new ReduceFusion failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::ReduceFusion>(); | |||
| if (tf_op.op() == "Sum") { | |||
| primitive_c->set_mode(mindspore::ReduceMode::Reduce_Sum); | |||
| prim->set_mode(mindspore::ReduceMode::Reduce_Sum); | |||
| } else if (tf_op.op() == "Max") { | |||
| primitive_c->set_mode(mindspore::ReduceMode::Reduce_Max); | |||
| prim->set_mode(mindspore::ReduceMode::Reduce_Max); | |||
| } else if (tf_op.op() == "Min") { | |||
| primitive_c->set_mode(mindspore::ReduceMode::Reduce_Min); | |||
| prim->set_mode(mindspore::ReduceMode::Reduce_Min); | |||
| } else if (tf_op.op() == "Mean") { | |||
| primitive_c->set_mode(mindspore::ReduceMode::Reduce_Mean); | |||
| prim->set_mode(mindspore::ReduceMode::Reduce_Mean); | |||
| } else if (tf_op.op() == "Prod") { | |||
| primitive_c->set_mode(mindspore::ReduceMode::Reduce_Prod); | |||
| prim->set_mode(mindspore::ReduceMode::Reduce_Prod); | |||
| } else if (tf_op.op() == "All") { | |||
| primitive_c->set_mode(mindspore::ReduceMode::Reduce_All); | |||
| prim->set_mode(mindspore::ReduceMode::Reduce_All); | |||
| } else { | |||
| MS_LOG(ERROR) << "unsupported reduce mode: " << tf_op.op(); | |||
| return nullptr; | |||
| @@ -59,7 +55,7 @@ ops::PrimitiveC *TFReduceParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| MS_LOG(ERROR) << "the keep_dims attr of reduce should be bool type"; | |||
| return nullptr; | |||
| } | |||
| primitive_c->set_keep_dims(attr_value.b()); | |||
| prim->set_keep_dims(attr_value.b()); | |||
| *output_size = 1; | |||
| if (AddOpInput(tf_op, 0, inputs) != RET_OK || AddOpInput(tf_op, 1, inputs) != RET_OK) { | |||
| @@ -67,7 +63,7 @@ ops::PrimitiveC *TFReduceParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| TFNodeRegistrar g_tfSumParser("Sum", new TFReduceParser()); | |||
| @@ -26,11 +26,7 @@ namespace lite { | |||
| ops::PrimitiveC *TFReshapeParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| auto primitive_c = new (std::nothrow) ops::Reshape; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Reshape failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Reshape>(); | |||
| *output_size = 1; | |||
| if (AddOpInput(tf_op, 0, inputs) != RET_OK || AddOpInput(tf_op, 1, inputs) != RET_OK) { | |||
| @@ -38,7 +34,7 @@ ops::PrimitiveC *TFReshapeParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| TFNodeRegistrar g_tfReshapeParser("Reshape", new TFReshapeParser()); | |||
| @@ -32,23 +32,19 @@ ops::PrimitiveC *TFReverseSequenceParser::Parse(const tensorflow::NodeDef &tf_op | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return nullptr; | |||
| } | |||
| auto primitive = new (std::nothrow) ops::ReverseSequence; | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "New ReverseSequenceParser failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::ReverseSequence>(); | |||
| tensorflow::AttrValue attr_value; | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "batch_dim", &attr_value)) { | |||
| MS_LOG(ERROR) << "The batch_dim attr should be specified"; | |||
| return nullptr; | |||
| } | |||
| primitive->set_batch_dim(attr_value.i()); | |||
| prim->set_batch_dim(attr_value.i()); | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "seq_dim", &attr_value)) { | |||
| MS_LOG(ERROR) << "The seq_dim attr should be specified"; | |||
| return nullptr; | |||
| } | |||
| primitive->set_seq_dim(attr_value.i()); | |||
| prim->set_seq_dim(attr_value.i()); | |||
| *output_size = 1; | |||
| if (AddOpInput(tf_op, 0, inputs) != RET_OK || AddOpInput(tf_op, 1, inputs) != RET_OK) { | |||
| @@ -56,7 +52,7 @@ ops::PrimitiveC *TFReverseSequenceParser::Parse(const tensorflow::NodeDef &tf_op | |||
| return nullptr; | |||
| } | |||
| return primitive; | |||
| return prim.release(); | |||
| } | |||
| TFNodeRegistrar g_tfReverseSequenceParser("ReverseSequence", new TFReverseSequenceParser()); | |||
| @@ -26,11 +26,7 @@ namespace lite { | |||
| ops::PrimitiveC *TFRoundParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| auto primitive_c = new (std::nothrow) ops::Round; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Round failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Round>(); | |||
| *output_size = 1; | |||
| if (AddOpInput(tf_op, 0, inputs) != RET_OK) { | |||
| @@ -38,7 +34,7 @@ ops::PrimitiveC *TFRoundParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| TFNodeRegistrar g_tfRoundParser("Round", new TFRoundParser()); | |||
| @@ -26,11 +26,7 @@ namespace lite { | |||
| ops::PrimitiveC *TFShapeParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| auto primitive_c = new (std::nothrow) ops::Shape; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Shape failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Shape>(); | |||
| *output_size = 1; | |||
| if (AddOpInput(tf_op, 0, inputs) != RET_OK) { | |||
| @@ -38,7 +34,7 @@ ops::PrimitiveC *TFShapeParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| TFNodeRegistrar g_tfShapeParser("Shape", new TFShapeParser()); | |||
| @@ -26,19 +26,15 @@ namespace lite { | |||
| ops::PrimitiveC *TFSplitParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| auto primitive_c = new (std::nothrow) ops::Split; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Split failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Split>(); | |||
| tensorflow::AttrValue attr_value; | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "num_split", &attr_value)) { | |||
| MS_LOG(ERROR) << "The attribute num_split should be specified"; | |||
| return nullptr; | |||
| } | |||
| auto numberSplit = attr_value.i(); | |||
| primitive_c->set_output_num(numberSplit); | |||
| auto number_split = attr_value.i(); | |||
| prim->set_output_num(number_split); | |||
| int split_dim_index = 2; | |||
| int input_index = 0; | |||
| @@ -57,7 +53,7 @@ ops::PrimitiveC *TFSplitParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| return nullptr; | |||
| } | |||
| auto splitDim = attr_value.tensor().int_val(0); | |||
| primitive_c->set_axis(splitDim); | |||
| prim->set_axis(splitDim); | |||
| if (tf_op.op() == "SplitV") { | |||
| auto size_splits_node = GetConstInputNode(tf_node_map, tf_op.input(1)); | |||
| @@ -80,16 +76,16 @@ ops::PrimitiveC *TFSplitParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| MS_LOG(ERROR) << "memcpy_s failed"; | |||
| return nullptr; | |||
| } | |||
| primitive_c->set_size_splits(sizeSplits); | |||
| prim->set_size_splits(sizeSplits); | |||
| } | |||
| *output_size = numberSplit; | |||
| *output_size = number_split; | |||
| if (AddOpInput(tf_op, input_index, inputs) != RET_OK) { | |||
| MS_LOG(ERROR) << "add op input failed"; | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| TFNodeRegistrar g_tfSplitParser("Split", new TFSplitParser()); | |||
| @@ -27,11 +27,7 @@ namespace lite { | |||
| ops::PrimitiveC *TFSqueezeParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| auto primitive_c = new (std::nothrow) ops::Squeeze; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new Squeeze failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::Squeeze>(); | |||
| std::vector<int64_t> axis; | |||
| tensorflow::AttrValue attr_value; | |||
| @@ -43,7 +39,7 @@ ops::PrimitiveC *TFSqueezeParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| for (int i = 0; i < dims.i_size(); ++i) { | |||
| axis.push_back(dims.i(i)); | |||
| } | |||
| primitive_c->set_axis(axis); | |||
| prim->set_axis(axis); | |||
| *output_size = 1; | |||
| if (AddOpInput(tf_op, 0, inputs) != RET_OK) { | |||
| @@ -51,7 +47,7 @@ ops::PrimitiveC *TFSqueezeParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| TFNodeRegistrar g_tfSqueezeParser("Squeeze", new TFSqueezeParser()); | |||
| @@ -27,42 +27,38 @@ namespace lite { | |||
| ops::PrimitiveC *TFStrideSliceParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| auto primitive_c = new (std::nothrow) ops::StridedSlice; | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "new StridedSlice failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::StridedSlice>(); | |||
| tensorflow::AttrValue attr_value; | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "begin_mask", &attr_value)) { | |||
| MS_LOG(ERROR) << "The begin_mask attr should be specified"; | |||
| return nullptr; | |||
| } | |||
| primitive_c->set_begin_mask(attr_value.i()); | |||
| prim->set_begin_mask(attr_value.i()); | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "end_mask", &attr_value)) { | |||
| MS_LOG(ERROR) << "The end_mask attr should be specified"; | |||
| return nullptr; | |||
| } | |||
| primitive_c->set_end_mask(attr_value.i()); | |||
| prim->set_end_mask(attr_value.i()); | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "ellipsis_mask", &attr_value)) { | |||
| MS_LOG(ERROR) << "The ellipsis_mask attr should be specified"; | |||
| return nullptr; | |||
| } | |||
| primitive_c->set_ellipsis_mask(attr_value.i()); | |||
| prim->set_ellipsis_mask(attr_value.i()); | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "new_axis_mask", &attr_value)) { | |||
| MS_LOG(ERROR) << "The new_axis_mask attr should be specified"; | |||
| return nullptr; | |||
| } | |||
| primitive_c->set_new_axis_mask(attr_value.i()); | |||
| prim->set_new_axis_mask(attr_value.i()); | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "shrink_axis_mask", &attr_value)) { | |||
| MS_LOG(ERROR) << "The shrink_axis_mask attr should be specified"; | |||
| return nullptr; | |||
| } | |||
| primitive_c->set_shrink_axis_mask(attr_value.i()); | |||
| prim->set_shrink_axis_mask(attr_value.i()); | |||
| *output_size = 1; | |||
| if (AddOpInput(tf_op, 0, inputs) != RET_OK || AddOpInput(tf_op, 1, inputs) != RET_OK || | |||
| @@ -71,7 +67,7 @@ ops::PrimitiveC *TFStrideSliceParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| return nullptr; | |||
| } | |||
| return primitive_c; | |||
| return prim.release(); | |||
| } | |||
| TFNodeRegistrar g_tfStrideSliceParser("StridedSlice", new TFStrideSliceParser()); | |||
| @@ -26,11 +26,7 @@ namespace lite { | |||
| ops::PrimitiveC *TFTensorListFromTensorParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| auto primitive = new (std::nothrow) ops::TensorListFromTensor; | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "New TensorListFromTensor failed"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<ops::TensorListFromTensor>(); | |||
| tensorflow::AttrValue attr_value; | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "element_dtype", &attr_value)) { | |||
| @@ -42,7 +38,7 @@ ops::PrimitiveC *TFTensorListFromTensorParser::Parse(const tensorflow::NodeDef & | |||
| MS_LOG(ERROR) << "tensor_list_from_tensor element dtype must be known type"; | |||
| return nullptr; | |||
| } | |||
| primitive->set_element_dtype((int64_t)(type)); | |||
| prim->set_element_dtype((int64_t)(type)); | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "shape_type", &attr_value)) { | |||
| MS_LOG(ERROR) << "The shape_type attr should be specified"; | |||
| @@ -53,7 +49,7 @@ ops::PrimitiveC *TFTensorListFromTensorParser::Parse(const tensorflow::NodeDef & | |||
| MS_LOG(ERROR) << "tensor_list_from_tensor shape type must be known type"; | |||
| return nullptr; | |||
| } | |||
| primitive->set_shape_type((int64_t)(type)); | |||
| prim->set_shape_type((int64_t)(type)); | |||
| *output_size = 1; | |||
| for (int i = 0; i < 2; ++i) { | |||
| @@ -63,7 +59,7 @@ ops::PrimitiveC *TFTensorListFromTensorParser::Parse(const tensorflow::NodeDef & | |||
| } | |||
| } | |||
| return primitive; | |||
| return prim.release(); | |||
| } | |||
| TFNodeRegistrar g_tfTensorListFromTensorParser("TensorListFromTensor", new TFTensorListFromTensorParser()); | |||