Browse Source

!13036 [lite]fix train bug

From: @xu_anyue
Reviewed-by: @hangangqiang,@jpc_chenjianping
Signed-off-by: @hangangqiang
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
7ba4f7a1dc
4 changed files with 10 additions and 5 deletions
  1. +6
    -0
      mindspore/lite/src/ops/ops_utils.cc
  2. +3
    -0
      mindspore/lite/tools/anf_exporter/anf_exporter.cc
  3. +1
    -1
      mindspore/lite/tools/converter/anf_transform.cc
  4. +0
    -4
      mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.cc

+ 6
- 0
mindspore/lite/src/ops/ops_utils.cc View File

@@ -608,6 +608,10 @@ schema::PrimitiveT *SpaceToDepthPrimitiveCreator(const AnfNodePtr &node) {
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::SpaceToDepth>>(node); auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::SpaceToDepth>>(node);
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
} }
schema::PrimitiveT *SparseSoftmaxCrossEntropyPrimitiveCreator(const AnfNodePtr &node) {
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::SparseSoftmaxCrossEntropy>>(node);
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
}
schema::PrimitiveT *SparseToDensePrimitiveCreator(const AnfNodePtr &node) { schema::PrimitiveT *SparseToDensePrimitiveCreator(const AnfNodePtr &node) {
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::SparseToDense>>(node); auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::SparseToDense>>(node);
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
@@ -876,6 +880,8 @@ RegistryMSOps g_softmaxCrossEntropyWithLogitsPrimitiveCreatorRegistry("SoftmaxCr
RegistryMSOps g_spaceToBatchPrimitiveCreatorRegistry("SpaceToBatch", SpaceToBatchPrimitiveCreator); RegistryMSOps g_spaceToBatchPrimitiveCreatorRegistry("SpaceToBatch", SpaceToBatchPrimitiveCreator);
RegistryMSOps g_spaceToBatchNDPrimitiveCreatorRegistry("SpaceToBatchND", SpaceToBatchNDPrimitiveCreator); RegistryMSOps g_spaceToBatchNDPrimitiveCreatorRegistry("SpaceToBatchND", SpaceToBatchNDPrimitiveCreator);
RegistryMSOps g_spaceToDepthPrimitiveCreatorRegistry("SpaceToDepth", SpaceToDepthPrimitiveCreator); RegistryMSOps g_spaceToDepthPrimitiveCreatorRegistry("SpaceToDepth", SpaceToDepthPrimitiveCreator);
RegistryMSOps g_sparseSoftmaxCrossEntropyPrimitiveCreatorRegistry("SparseSoftmaxCrossEntropyWithLogits",
SparseSoftmaxCrossEntropyPrimitiveCreator);
RegistryMSOps g_sparseToDensePrimitiveCreatorRegistry("SparseToDense", SparseToDensePrimitiveCreator); RegistryMSOps g_sparseToDensePrimitiveCreatorRegistry("SparseToDense", SparseToDensePrimitiveCreator);
RegistryMSOps g_splitPrimitiveCreatorRegistry("Split", SplitPrimitiveCreator); RegistryMSOps g_splitPrimitiveCreatorRegistry("Split", SplitPrimitiveCreator);
RegistryMSOps g_sqrtPrimitiveCreatorRegistry("Sqrt", SqrtPrimitiveCreator); RegistryMSOps g_sqrtPrimitiveCreatorRegistry("Sqrt", SqrtPrimitiveCreator);


+ 3
- 0
mindspore/lite/tools/anf_exporter/anf_exporter.cc View File

@@ -770,6 +770,9 @@ int AnfExporter::ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_ano
} else if (value->isa<FuncGraph>()) { } else if (value->isa<FuncGraph>()) {
MS_LOG(INFO) << "op name:" << input_anode->fullname_with_scope() << " input is func_graph"; MS_LOG(INFO) << "op name:" << input_anode->fullname_with_scope() << " input is func_graph";
return RET_OK; return RET_OK;
} else if (value->isa<Monad>()) {
MS_LOG(INFO) << "value is a monad.";
return RET_OK;
} else { } else {
MS_LOG(ERROR) << "Not support value type , need add support."; MS_LOG(ERROR) << "Not support value type , need add support.";
return RET_ERROR; return RET_ERROR;


+ 1
- 1
mindspore/lite/tools/converter/anf_transform.cc View File

@@ -143,8 +143,8 @@ int AnfTransform::AddConvertPass(const std::shared_ptr<opt::GraphOptimizer> &opt
int AnfTransform::AddConstFoldPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, int AnfTransform::AddConstFoldPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer,
const converter::Flags *config) { const converter::Flags *config) {
auto const_fold_pm = std::make_shared<opt::PassManager>("const fold fusion pass manager", false); auto const_fold_pm = std::make_shared<opt::PassManager>("const fold fusion pass manager", false);
const_fold_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>());
if (!config->trainModel) { if (!config->trainModel) {
const_fold_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>());
auto inne_context_ptr = std::make_shared<lite::InnerContext>(); auto inne_context_ptr = std::make_shared<lite::InnerContext>();
inne_context_ptr->Init(); inne_context_ptr->Init();
const_fold_pm->AddPass(std::make_shared<opt::ConstFoldPass>(inne_context_ptr)); const_fold_pm->AddPass(std::make_shared<opt::ConstFoldPass>(inne_context_ptr));


+ 0
- 4
mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.cc View File

@@ -98,10 +98,6 @@ bool RemoveRedundantOpPass::Run(const FuncGraphPtr &func_graph) {
if (CheckPrimitiveType(node, prim::kPrimUpdateState)) { if (CheckPrimitiveType(node, prim::kPrimUpdateState)) {
status = ReplaceOp(node, manager); status = ReplaceOp(node, manager);
} }
if (CheckPrimitiveType(node, prim::kPrimDepend) ||
CheckPrimitiveType(node, prim::kPrimControlDepend)) { // ControlDepend delete next version.
status = ReplaceOp(node, manager);
}
if (CheckPrimitiveType(node, prim::kPrimTupleGetItem)) { if (CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
status = ReplaceTupleGetItem(node, manager); status = ReplaceTupleGetItem(node, manager);
} }


Loading…
Cancel
Save