Browse Source

!13152 fix floor_div_parser and random_standard_norm bug

From: @zhaodezan
Reviewed-by: @jpc_chenjianping,@zhanghaibo5
Signed-off-by: @jpc_chenjianping
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
a5b7617876
3 changed files with 32 additions and 0 deletions
  1. +6
    -0
      mindspore/lite/src/ops/ops_utils.cc
  2. +16
    -0
      mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.cc
  3. +10
    -0
      mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.h

+ 6
- 0
mindspore/lite/src/ops/ops_utils.cc View File

@@ -488,6 +488,10 @@ schema::PrimitiveT *RangePrimitiveCreator(const AnfNodePtr &node) {
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Range>>(node);
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
}
schema::PrimitiveT *RandomStandardNormalPrimitiveCreator(const AnfNodePtr &node) {
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::RandomStandardNormal>>(node);
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
}
schema::PrimitiveT *RankPrimitiveCreator(const AnfNodePtr &node) {
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Rank>>(node);
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
@@ -843,6 +847,8 @@ RegistryMSOps g_partialFusionPrimitiveCreatorRegistry("PartialFusion", PartialFu
RegistryMSOps g_powerGradPrimitiveCreatorRegistry("PowerGrad", PowerGradPrimitiveCreator);
RegistryMSOps g_powFusionPrimitiveCreatorRegistry("PowFusion", PowFusionPrimitiveCreator);
RegistryMSOps g_pReLUFusionPrimitiveCreatorRegistry("PReLUFusion", PReLUFusionPrimitiveCreator);
RegistryMSOps g_RandomStandardNormalPrimitiveCreatorRegistry("RandomStandardNormal",
RandomStandardNormalPrimitiveCreator);
RegistryMSOps g_rangePrimitiveCreatorRegistry("Range", RangePrimitiveCreator);
RegistryMSOps g_rankPrimitiveCreatorRegistry("Rank", RankPrimitiveCreator);
RegistryMSOps g_reciprocalPrimitiveCreatorRegistry("Reciprocal", ReciprocalPrimitiveCreator);


+ 16
- 0
mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.cc View File

@@ -37,6 +37,7 @@
#include "ops/ceil.h"
#include "ops/fusion/exp_fusion.h"
#include "ops/floor.h"
#include "ops/floor_div.h"
#include "ops/floor_mod.h"
#include "ops/log.h"
#include "ops/sqrt.h"
@@ -299,6 +300,20 @@ ops::PrimitiveC *TFFloorParser::Parse(const tensorflow::NodeDef &tf_op,
return prim.release();
}

ops::PrimitiveC *TFFloorDivParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
auto prim = std::make_unique<ops::FloorDiv>();

*output_size = 1;
if (AddOpInput(tf_op, 0, inputs) != RET_OK || AddOpInput(tf_op, 1, inputs) != RET_OK) {
MS_LOG(ERROR) << "add op input failed";
return nullptr;
}

return prim.release();
}

ops::PrimitiveC *TFFloorModParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
@@ -435,6 +450,7 @@ TFNodeRegistrar g_tfSquareParser("Square", new TFSquareParser());
TFNodeRegistrar g_tfCeilParser("Ceil", new TFCeilParser());
TFNodeRegistrar g_tfExpParser("Exp", new TFExpParser());
TFNodeRegistrar g_tfFloorParser("Floor", new TFFloorParser());
TFNodeRegistrar g_tfFloorDivParser("FloorDiv", new TFFloorDivParser());
TFNodeRegistrar g_tfFloorModParser("FloorMod", new TFFloorModParser());
TFNodeRegistrar g_tfLogParser("Log", new TFLogParser());
TFNodeRegistrar g_tfSqrtParser("Sqrt", new TFSqrtParser());


+ 10
- 0
mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.h View File

@@ -204,6 +204,16 @@ class TFFloorParser : public TFNodeParser {
std::vector<std::string> *inputs, int *output_size) override;
};

class TFFloorDivParser : public TFNodeParser {
public:
TFFloorDivParser() = default;
~TFFloorDivParser() override = default;

ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) override;
};

class TFFloorModParser : public TFNodeParser {
public:
TFFloorModParser() = default;


Loading…
Cancel
Save