From: @zhaodezan Reviewed-by: @jpc_chenjianping,@zhanghaibo5 Signed-off-by: @jpc_chenjianpingtags/v1.2.0-rc1
| @@ -488,6 +488,10 @@ schema::PrimitiveT *RangePrimitiveCreator(const AnfNodePtr &node) { | |||||
| auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Range>>(node); | auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Range>>(node); | ||||
| return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | ||||
| } | } | ||||
| schema::PrimitiveT *RandomStandardNormalPrimitiveCreator(const AnfNodePtr &node) { | |||||
| auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::RandomStandardNormal>>(node); | |||||
| return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | |||||
| } | |||||
| schema::PrimitiveT *RankPrimitiveCreator(const AnfNodePtr &node) { | schema::PrimitiveT *RankPrimitiveCreator(const AnfNodePtr &node) { | ||||
| auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Rank>>(node); | auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Rank>>(node); | ||||
| return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | ||||
| @@ -843,6 +847,8 @@ RegistryMSOps g_partialFusionPrimitiveCreatorRegistry("PartialFusion", PartialFu | |||||
| RegistryMSOps g_powerGradPrimitiveCreatorRegistry("PowerGrad", PowerGradPrimitiveCreator); | RegistryMSOps g_powerGradPrimitiveCreatorRegistry("PowerGrad", PowerGradPrimitiveCreator); | ||||
| RegistryMSOps g_powFusionPrimitiveCreatorRegistry("PowFusion", PowFusionPrimitiveCreator); | RegistryMSOps g_powFusionPrimitiveCreatorRegistry("PowFusion", PowFusionPrimitiveCreator); | ||||
| RegistryMSOps g_pReLUFusionPrimitiveCreatorRegistry("PReLUFusion", PReLUFusionPrimitiveCreator); | RegistryMSOps g_pReLUFusionPrimitiveCreatorRegistry("PReLUFusion", PReLUFusionPrimitiveCreator); | ||||
| RegistryMSOps g_RandomStandardNormalPrimitiveCreatorRegistry("RandomStandardNormal", | |||||
| RandomStandardNormalPrimitiveCreator); | |||||
| RegistryMSOps g_rangePrimitiveCreatorRegistry("Range", RangePrimitiveCreator); | RegistryMSOps g_rangePrimitiveCreatorRegistry("Range", RangePrimitiveCreator); | ||||
| RegistryMSOps g_rankPrimitiveCreatorRegistry("Rank", RankPrimitiveCreator); | RegistryMSOps g_rankPrimitiveCreatorRegistry("Rank", RankPrimitiveCreator); | ||||
| RegistryMSOps g_reciprocalPrimitiveCreatorRegistry("Reciprocal", ReciprocalPrimitiveCreator); | RegistryMSOps g_reciprocalPrimitiveCreatorRegistry("Reciprocal", ReciprocalPrimitiveCreator); | ||||
| @@ -37,6 +37,7 @@ | |||||
| #include "ops/ceil.h" | #include "ops/ceil.h" | ||||
| #include "ops/fusion/exp_fusion.h" | #include "ops/fusion/exp_fusion.h" | ||||
| #include "ops/floor.h" | #include "ops/floor.h" | ||||
| #include "ops/floor_div.h" | |||||
| #include "ops/floor_mod.h" | #include "ops/floor_mod.h" | ||||
| #include "ops/log.h" | #include "ops/log.h" | ||||
| #include "ops/sqrt.h" | #include "ops/sqrt.h" | ||||
| @@ -299,6 +300,20 @@ ops::PrimitiveC *TFFloorParser::Parse(const tensorflow::NodeDef &tf_op, | |||||
| return prim.release(); | return prim.release(); | ||||
| } | } | ||||
| ops::PrimitiveC *TFFloorDivParser::Parse(const tensorflow::NodeDef &tf_op, | |||||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||||
| std::vector<std::string> *inputs, int *output_size) { | |||||
| auto prim = std::make_unique<ops::FloorDiv>(); | |||||
| *output_size = 1; | |||||
| if (AddOpInput(tf_op, 0, inputs) != RET_OK || AddOpInput(tf_op, 1, inputs) != RET_OK) { | |||||
| MS_LOG(ERROR) << "add op input failed"; | |||||
| return nullptr; | |||||
| } | |||||
| return prim.release(); | |||||
| } | |||||
| ops::PrimitiveC *TFFloorModParser::Parse(const tensorflow::NodeDef &tf_op, | ops::PrimitiveC *TFFloorModParser::Parse(const tensorflow::NodeDef &tf_op, | ||||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | ||||
| std::vector<std::string> *inputs, int *output_size) { | std::vector<std::string> *inputs, int *output_size) { | ||||
| @@ -435,6 +450,7 @@ TFNodeRegistrar g_tfSquareParser("Square", new TFSquareParser()); | |||||
| TFNodeRegistrar g_tfCeilParser("Ceil", new TFCeilParser()); | TFNodeRegistrar g_tfCeilParser("Ceil", new TFCeilParser()); | ||||
| TFNodeRegistrar g_tfExpParser("Exp", new TFExpParser()); | TFNodeRegistrar g_tfExpParser("Exp", new TFExpParser()); | ||||
| TFNodeRegistrar g_tfFloorParser("Floor", new TFFloorParser()); | TFNodeRegistrar g_tfFloorParser("Floor", new TFFloorParser()); | ||||
| TFNodeRegistrar g_tfFloorDivParser("FloorDiv", new TFFloorDivParser()); | |||||
| TFNodeRegistrar g_tfFloorModParser("FloorMod", new TFFloorModParser()); | TFNodeRegistrar g_tfFloorModParser("FloorMod", new TFFloorModParser()); | ||||
| TFNodeRegistrar g_tfLogParser("Log", new TFLogParser()); | TFNodeRegistrar g_tfLogParser("Log", new TFLogParser()); | ||||
| TFNodeRegistrar g_tfSqrtParser("Sqrt", new TFSqrtParser()); | TFNodeRegistrar g_tfSqrtParser("Sqrt", new TFSqrtParser()); | ||||
| @@ -204,6 +204,16 @@ class TFFloorParser : public TFNodeParser { | |||||
| std::vector<std::string> *inputs, int *output_size) override; | std::vector<std::string> *inputs, int *output_size) override; | ||||
| }; | }; | ||||
| class TFFloorDivParser : public TFNodeParser { | |||||
| public: | |||||
| TFFloorDivParser() = default; | |||||
| ~TFFloorDivParser() override = default; | |||||
| ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op, | |||||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||||
| std::vector<std::string> *inputs, int *output_size) override; | |||||
| }; | |||||
| class TFFloorModParser : public TFNodeParser { | class TFFloorModParser : public TFNodeParser { | ||||
| public: | public: | ||||
| TFFloorModParser() = default; | TFFloorModParser() = default; | ||||