From 7c66c1685be57a47a4e8fcc17564c8b3e3a2b203 Mon Sep 17 00:00:00 2001 From: xuanyue Date: Mon, 26 Apr 2021 21:05:38 +0800 Subject: [PATCH] fix train bug --- .../lite/tools/anf_exporter/fetch_content.cc | 76 +++++++----- .../graph/redundant_op_remove_pass.cc | 108 +++++++++++++----- 2 files changed, 125 insertions(+), 59 deletions(-) diff --git a/mindspore/lite/tools/anf_exporter/fetch_content.cc b/mindspore/lite/tools/anf_exporter/fetch_content.cc index fd86870372..5923fc5c53 100644 --- a/mindspore/lite/tools/anf_exporter/fetch_content.cc +++ b/mindspore/lite/tools/anf_exporter/fetch_content.cc @@ -110,6 +110,40 @@ STATUS GetDataTypeAndShape(const ParameterPtr ¶m_node, TypeId *data_type, Sh return RET_OK; } +int FetchFromDefaultParam(const ParameterPtr ¶m_node, DataInfo *data_info) { + MS_ASSERT(param_node != nullptr && data_info != nullptr); + ShapeVector shape_vector; + TypeId data_type; + auto status = GetDataTypeAndShape(param_node, &data_type, &shape_vector); + if (status != RET_OK) { + MS_LOG(ERROR) << "get data type and shape from param node failed."; + return RET_ERROR; + } + data_info->data_type_ = data_type; + auto tensor_info = std::dynamic_pointer_cast(param_node->default_param()); + size_t offset = 0; + if (!shape_vector.empty() && data_type == kObjectTypeString) { + status = GetShapeVectorFromStringTensor(tensor_info, &shape_vector, &offset); + if (status != RET_OK) { + MS_LOG(ERROR) << "get shape vector from string tensor failed."; + return RET_ERROR; + } + } + std::vector dims(shape_vector.begin(), shape_vector.end()); + data_info->shape_ = dims; + if (tensor_info != nullptr && tensor_info->Size() != 0) { + if (data_type != kObjectTypeTensorType || tensor_info->Size() >= kTensorListMinSize) { + data_info->data_.resize(tensor_info->Size() - offset); + if (EOK != memcpy_s(data_info->data_.data(), data_info->data_.size(), + static_cast(tensor_info->data_c()) + offset, tensor_info->Size() - offset)) { + MS_LOG(ERROR) << "memcpy_s failed."; + return RET_ERROR; + } + } + } + return RET_OK; +} + int FetchFromTensorValue(const ValueNodePtr &value_node, const PrimitivePtr &primitive, converter::FmkType fmk_type, bool train_flag, DataInfo *data_info) { MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr); @@ -230,10 +264,14 @@ int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::F DataInfo *data_info) { MS_ASSERT(cnode != nullptr && data_info != nullptr); auto param_node = cnode->input(index)->cast(); + if (param_node == nullptr) { + MS_LOG(ERROR) << "input node is not parameter node."; + return RET_ERROR; + } data_info->format_ = GetFormatByFmk(fmk_type); if (data_info->format_ < 0) { MS_LOG(ERROR) << "don't support current fmk: " << fmk_type; - return lite::RET_ERROR; + return RET_ERROR; } if (data_info->format_ != mindspore::NHWC && data_info->format_ != mindspore::NCHW) { MS_LOG(ERROR) << "schema tensor format is wrong, " << data_info->format_; @@ -245,38 +283,14 @@ int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::F if (index == 2 && prim->GetAttr(opt::kWeightFormat) != nullptr) { data_info->format_ = GetValue(prim->GetAttr(opt::kWeightFormat)); } - ShapeVector shape_vector; - TypeId data_type; - auto status = GetDataTypeAndShape(param_node, &data_type, &shape_vector); - if (status != RET_OK) { - MS_LOG(ERROR) << "get data type and shape from param node failed."; + if (FetchFromDefaultParam(param_node, data_info) != RET_OK) { + MS_LOG(ERROR) << "fetch information from default param failed."; return RET_ERROR; } - data_info->data_type_ = data_type; - auto tensor_info = std::dynamic_pointer_cast(param_node->default_param()); - size_t offset = 0; - if (!shape_vector.empty() && data_type == kObjectTypeString) { - status = GetShapeVectorFromStringTensor(tensor_info, &shape_vector, &offset); - if (status != RET_OK) { - MS_LOG(ERROR) << "get shape vector from string tensor failed."; - return RET_ERROR; - } - } - std::vector dims(shape_vector.begin(), shape_vector.end()); - data_info->shape_ = dims; - if (tensor_info != nullptr && tensor_info->Size() != 0) { - if (data_type != kObjectTypeTensorType || tensor_info->Size() >= kTensorListMinSize) { - data_info->data_.resize(tensor_info->Size() - offset); - if (EOK != memcpy_s(data_info->data_.data(), data_info->data_.size(), - static_cast(tensor_info->data_c()) + offset, tensor_info->Size() - offset)) { - MS_LOG(ERROR) << "memcpy_s failed."; - return RET_ERROR; - } - } - } QuantParamHolderPtr quant_param_holder = prim->GetAttr("quant_params") == nullptr ? nullptr : prim->GetAttr("quant_params")->cast(); - if (quant_param_holder != nullptr && quant_param_holder->enable_huffman_code() && data_type == kNumberTypeInt8) { + if (quant_param_holder != nullptr && quant_param_holder->enable_huffman_code() && + data_info->data_type_ == kNumberTypeInt8) { data_info->enable_huffman_code_ = true; } data_info->node_type_ = NodeType_ValueNode; @@ -287,6 +301,10 @@ int FetchDataFromValueNode(const CNodePtr &cnode, size_t index, converter::FmkTy DataInfo *data_info) { MS_ASSERT(cnode != nullptr && data_info != nullptr); auto value_node = cnode->input(index)->cast(); + if (value_node == nullptr) { + MS_LOG(ERROR) << "input node is not value node."; + return RET_ERROR; + } auto value = value_node->value(); int ret = RET_OK; auto prim = GetValueNode(cnode->input(0)); diff --git a/mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.cc b/mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.cc index b6aaf5151b..f583704237 100644 --- a/mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.cc @@ -18,26 +18,91 @@ #include #include #include "include/errorcode.h" +#include "ops/depend.h" #include "ops/make_tuple.h" namespace mindspore::opt { namespace { constexpr size_t kInputDoubleNum = 2; constexpr size_t kInputTripleNum = 3; -void FetchCNodeFromMakeTuple(const AnfNodePtr &anf_node, std::vector *inputs) { - MS_ASSERT(anf_node != nullptr); - MS_ASSERT(inputs != nullptr); - auto cnode = anf_node->cast(); - if (cnode == nullptr) { - return; +int ProcessInputIsMonad(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { + MS_ASSERT(func_graph != nullptr && cnode != nullptr); + auto first_input = cnode->input(1); + auto second_input = cnode->input(2); + AnfNodePtr must_monad = nullptr; + AnfNodePtr not_must_monad = nullptr; + if (utils::isa(first_input)) { + auto value_node = first_input->cast(); + MS_ASSERT(value_node->value() != nullptr); + if (utils::isa(value_node->value())) { + must_monad = first_input; + not_must_monad = second_input; + } } - for (size_t i = 1; i < cnode->size(); ++i) { - if (cnode->input(i)->isa()) { - inputs->push_back(cnode->input(i)); + if (utils::isa(second_input)) { + auto value_node = second_input->cast(); + MS_ASSERT(value_node->value() != nullptr); + if (utils::isa(value_node->value())) { + must_monad = second_input; + not_must_monad = first_input; } } + if (must_monad == nullptr) { + return lite::RET_NO_CHANGE; + } + auto manager = func_graph->manager(); + MS_ASSERT(manager != nullptr); + if (!utils::isa(not_must_monad) || CheckIsAllInputsParam(not_must_monad)) { + manager->Replace(cnode, must_monad); + } else { + manager->Replace(cnode, not_must_monad); + } + return lite::RET_OK; +} + +int ProcessDependencyWithTwoNodes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, bool pre_node_is_first) { + MS_ASSERT(func_graph != nullptr && cnode != nullptr); + AnfNodePtr pre_node = cnode->input(1); + AnfNodePtr post_node = cnode->input(2); + if (!pre_node_is_first) { + pre_node = cnode->input(2); + post_node = cnode->input(1); + } + auto manager = func_graph->manager(); + MS_ASSERT(manager != nullptr); + auto node_users = manager->node_users()[pre_node]; + auto iter = + std::find_if(node_users.begin(), node_users.end(), + [&post_node](const std::pair &post_pair) { return post_pair.first == post_node; }); + if (iter == node_users.end()) { + return lite::RET_NO_CHANGE; + } + auto tr = manager->Transact(); + tr.SetEdge(post_node, iter->second, NewValueNode(std::make_shared())); + tr.Commit(); + auto depend_prim = std::make_shared(); + auto depend_node = func_graph->NewCNode(depend_prim, {post_node, pre_node}); + depend_node->set_fullname_with_scope(cnode->fullname_with_scope()); + manager->Replace(cnode, depend_node); + return lite::RET_OK; +} + +int ProcessInputHaveDependency(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { + MS_ASSERT(func_graph != nullptr && cnode != nullptr); + if (ProcessDependencyWithTwoNodes(func_graph, cnode, true) == lite::RET_OK) { + return lite::RET_OK; + } + if (ProcessDependencyWithTwoNodes(func_graph, cnode, false) == lite::RET_OK) { + return lite::RET_OK; + } + auto make_tuple_prim = NewValueNode(std::make_shared()); + auto manager = func_graph->manager(); + MS_ASSERT(manager != nullptr); + manager->Replace(cnode->input(0), make_tuple_prim); + return lite::RET_OK; } } // namespace + int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) { if (!utils::isa(anf_node)) { MS_LOG(DEBUG) << "anf node is node a cnode."; @@ -73,28 +138,11 @@ int RemoveRedundantOpPass::ReplaceUpdateStateOp(const FuncGraphPtr &func_graph, return lite::RET_NO_CHANGE; } auto cnode = anf_node->cast(); - auto inputs = cnode->inputs(); - std::vector new_inputs; - for (size_t i = 1; i < inputs.size(); ++i) { - if (!inputs[i]->isa()) { - continue; - } - if (CheckPrimitiveType(inputs[i], prim::kPrimMakeTuple)) { - FetchCNodeFromMakeTuple(inputs[i], &new_inputs); - continue; - } - new_inputs.push_back(inputs[i]); + if (ProcessInputIsMonad(func_graph, cnode) == lite::RET_OK) { + return lite::RET_OK; } - for (auto &node : new_inputs) { - func_graph->get_return()->add_input(node); - } - auto value = std::make_shared(); - bool replace_succ = func_graph->manager()->Replace(anf_node, NewValueNode(value)); - if (!replace_succ) { - MS_LOG(ERROR) << "replace redundant op failed."; - return lite::RET_ERROR; - } - return RET_OK; + // both of two inputs are not monad, but have dependency. + return ProcessInputHaveDependency(func_graph, cnode); } int RemoveRedundantOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {