From ae64426d6cbfc4a47cee4d0705549a0dd0e0e424 Mon Sep 17 00:00:00 2001 From: zhaodezan Date: Thu, 11 Mar 2021 14:41:42 +0800 Subject: [PATCH] fix floor_div_parser and random_standard_norm bug --- mindspore/lite/src/ops/ops_utils.cc | 6 ++++++ .../converter/parser/tf/tf_arithmetic_parser.cc | 16 ++++++++++++++++ .../converter/parser/tf/tf_arithmetic_parser.h | 10 ++++++++++ 3 files changed, 32 insertions(+) diff --git a/mindspore/lite/src/ops/ops_utils.cc b/mindspore/lite/src/ops/ops_utils.cc index 34c4a86aa7..f7bfe6f136 100644 --- a/mindspore/lite/src/ops/ops_utils.cc +++ b/mindspore/lite/src/ops/ops_utils.cc @@ -488,6 +488,10 @@ schema::PrimitiveT *RangePrimitiveCreator(const AnfNodePtr &node) { auto ms_primc = GetValueNode>(node); return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; } +schema::PrimitiveT *RandomStandardNormalPrimitiveCreator(const AnfNodePtr &node) { + auto ms_primc = GetValueNode>(node); + return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; +} schema::PrimitiveT *RankPrimitiveCreator(const AnfNodePtr &node) { auto ms_primc = GetValueNode>(node); return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; @@ -843,6 +847,8 @@ RegistryMSOps g_partialFusionPrimitiveCreatorRegistry("PartialFusion", PartialFu RegistryMSOps g_powerGradPrimitiveCreatorRegistry("PowerGrad", PowerGradPrimitiveCreator); RegistryMSOps g_powFusionPrimitiveCreatorRegistry("PowFusion", PowFusionPrimitiveCreator); RegistryMSOps g_pReLUFusionPrimitiveCreatorRegistry("PReLUFusion", PReLUFusionPrimitiveCreator); +RegistryMSOps g_RandomStandardNormalPrimitiveCreatorRegistry("RandomStandardNormal", + RandomStandardNormalPrimitiveCreator); RegistryMSOps g_rangePrimitiveCreatorRegistry("Range", RangePrimitiveCreator); RegistryMSOps g_rankPrimitiveCreatorRegistry("Rank", RankPrimitiveCreator); RegistryMSOps g_reciprocalPrimitiveCreatorRegistry("Reciprocal", ReciprocalPrimitiveCreator); diff --git a/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.cc index 588bbd95e1..ceb592bbde 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.cc @@ -37,6 +37,7 @@ #include "ops/ceil.h" #include "ops/fusion/exp_fusion.h" #include "ops/floor.h" +#include "ops/floor_div.h" #include "ops/floor_mod.h" #include "ops/log.h" #include "ops/sqrt.h" @@ -299,6 +300,20 @@ ops::PrimitiveC *TFFloorParser::Parse(const tensorflow::NodeDef &tf_op, return prim.release(); } +ops::PrimitiveC *TFFloorDivParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, + std::vector *inputs, int *output_size) { + auto prim = std::make_unique(); + + *output_size = 1; + if (AddOpInput(tf_op, 0, inputs) != RET_OK || AddOpInput(tf_op, 1, inputs) != RET_OK) { + MS_LOG(ERROR) << "add op input failed"; + return nullptr; + } + + return prim.release(); +} + ops::PrimitiveC *TFFloorModParser::Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, std::vector *inputs, int *output_size) { @@ -435,6 +450,7 @@ TFNodeRegistrar g_tfSquareParser("Square", new TFSquareParser()); TFNodeRegistrar g_tfCeilParser("Ceil", new TFCeilParser()); TFNodeRegistrar g_tfExpParser("Exp", new TFExpParser()); TFNodeRegistrar g_tfFloorParser("Floor", new TFFloorParser()); +TFNodeRegistrar g_tfFloorDivParser("FloorDiv", new TFFloorDivParser()); TFNodeRegistrar g_tfFloorModParser("FloorMod", new TFFloorModParser()); TFNodeRegistrar g_tfLogParser("Log", new TFLogParser()); TFNodeRegistrar g_tfSqrtParser("Sqrt", new TFSqrtParser()); diff --git a/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.h index 305e761ace..98cf5551d4 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.h +++ b/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.h @@ -204,6 +204,16 @@ class TFFloorParser : public TFNodeParser { std::vector *inputs, int *output_size) override; }; +class TFFloorDivParser : public TFNodeParser { + public: + TFFloorDivParser() = default; + ~TFFloorDivParser() override = default; + + ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, + std::vector *inputs, int *output_size) override; +}; + class TFFloorModParser : public TFNodeParser { public: TFFloorModParser() = default;