From: @xu_anyue Reviewed-by: @HilbertDavid,@jpc_chenjianping Signed-off-by: @jpc_chenjianpingpull/15725/MERGE
| @@ -110,6 +110,40 @@ STATUS GetDataTypeAndShape(const ParameterPtr ¶m_node, TypeId *data_type, Sh | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int FetchFromDefaultParam(const ParameterPtr ¶m_node, DataInfo *data_info) { | |||||
| MS_ASSERT(param_node != nullptr && data_info != nullptr); | |||||
| ShapeVector shape_vector; | |||||
| TypeId data_type; | |||||
| auto status = GetDataTypeAndShape(param_node, &data_type, &shape_vector); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "get data type and shape from param node failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| data_info->data_type_ = data_type; | |||||
| auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(param_node->default_param()); | |||||
| size_t offset = 0; | |||||
| if (!shape_vector.empty() && data_type == kObjectTypeString) { | |||||
| status = GetShapeVectorFromStringTensor(tensor_info, &shape_vector, &offset); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "get shape vector from string tensor failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end()); | |||||
| data_info->shape_ = dims; | |||||
| if (tensor_info != nullptr && tensor_info->Size() != 0) { | |||||
| if (data_type != kObjectTypeTensorType || tensor_info->Size() >= kTensorListMinSize) { | |||||
| data_info->data_.resize(tensor_info->Size() - offset); | |||||
| if (EOK != memcpy_s(data_info->data_.data(), data_info->data_.size(), | |||||
| static_cast<uint8_t *>(tensor_info->data_c()) + offset, tensor_info->Size() - offset)) { | |||||
| MS_LOG(ERROR) << "memcpy_s failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int FetchFromTensorValue(const ValueNodePtr &value_node, const PrimitivePtr &primitive, converter::FmkType fmk_type, | int FetchFromTensorValue(const ValueNodePtr &value_node, const PrimitivePtr &primitive, converter::FmkType fmk_type, | ||||
| bool train_flag, DataInfo *data_info) { | bool train_flag, DataInfo *data_info) { | ||||
| MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr); | MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr); | ||||
| @@ -230,10 +264,14 @@ int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::F | |||||
| DataInfo *data_info) { | DataInfo *data_info) { | ||||
| MS_ASSERT(cnode != nullptr && data_info != nullptr); | MS_ASSERT(cnode != nullptr && data_info != nullptr); | ||||
| auto param_node = cnode->input(index)->cast<ParameterPtr>(); | auto param_node = cnode->input(index)->cast<ParameterPtr>(); | ||||
| if (param_node == nullptr) { | |||||
| MS_LOG(ERROR) << "input node is not parameter node."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| data_info->format_ = GetFormatByFmk(fmk_type); | data_info->format_ = GetFormatByFmk(fmk_type); | ||||
| if (data_info->format_ < 0) { | if (data_info->format_ < 0) { | ||||
| MS_LOG(ERROR) << "don't support current fmk: " << fmk_type; | MS_LOG(ERROR) << "don't support current fmk: " << fmk_type; | ||||
| return lite::RET_ERROR; | |||||
| return RET_ERROR; | |||||
| } | } | ||||
| if (data_info->format_ != mindspore::NHWC && data_info->format_ != mindspore::NCHW) { | if (data_info->format_ != mindspore::NHWC && data_info->format_ != mindspore::NCHW) { | ||||
| MS_LOG(ERROR) << "schema tensor format is wrong, " << data_info->format_; | MS_LOG(ERROR) << "schema tensor format is wrong, " << data_info->format_; | ||||
| @@ -245,38 +283,14 @@ int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::F | |||||
| if (index == 2 && prim->GetAttr(opt::kWeightFormat) != nullptr) { | if (index == 2 && prim->GetAttr(opt::kWeightFormat) != nullptr) { | ||||
| data_info->format_ = GetValue<int64_t>(prim->GetAttr(opt::kWeightFormat)); | data_info->format_ = GetValue<int64_t>(prim->GetAttr(opt::kWeightFormat)); | ||||
| } | } | ||||
| ShapeVector shape_vector; | |||||
| TypeId data_type; | |||||
| auto status = GetDataTypeAndShape(param_node, &data_type, &shape_vector); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "get data type and shape from param node failed."; | |||||
| if (FetchFromDefaultParam(param_node, data_info) != RET_OK) { | |||||
| MS_LOG(ERROR) << "fetch information from default param failed."; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| data_info->data_type_ = data_type; | |||||
| auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(param_node->default_param()); | |||||
| size_t offset = 0; | |||||
| if (!shape_vector.empty() && data_type == kObjectTypeString) { | |||||
| status = GetShapeVectorFromStringTensor(tensor_info, &shape_vector, &offset); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "get shape vector from string tensor failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end()); | |||||
| data_info->shape_ = dims; | |||||
| if (tensor_info != nullptr && tensor_info->Size() != 0) { | |||||
| if (data_type != kObjectTypeTensorType || tensor_info->Size() >= kTensorListMinSize) { | |||||
| data_info->data_.resize(tensor_info->Size() - offset); | |||||
| if (EOK != memcpy_s(data_info->data_.data(), data_info->data_.size(), | |||||
| static_cast<uint8_t *>(tensor_info->data_c()) + offset, tensor_info->Size() - offset)) { | |||||
| MS_LOG(ERROR) << "memcpy_s failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| } | |||||
| QuantParamHolderPtr quant_param_holder = | QuantParamHolderPtr quant_param_holder = | ||||
| prim->GetAttr("quant_params") == nullptr ? nullptr : prim->GetAttr("quant_params")->cast<QuantParamHolderPtr>(); | prim->GetAttr("quant_params") == nullptr ? nullptr : prim->GetAttr("quant_params")->cast<QuantParamHolderPtr>(); | ||||
| if (quant_param_holder != nullptr && quant_param_holder->enable_huffman_code() && data_type == kNumberTypeInt8) { | |||||
| if (quant_param_holder != nullptr && quant_param_holder->enable_huffman_code() && | |||||
| data_info->data_type_ == kNumberTypeInt8) { | |||||
| data_info->enable_huffman_code_ = true; | data_info->enable_huffman_code_ = true; | ||||
| } | } | ||||
| data_info->node_type_ = NodeType_ValueNode; | data_info->node_type_ = NodeType_ValueNode; | ||||
| @@ -287,6 +301,10 @@ int FetchDataFromValueNode(const CNodePtr &cnode, size_t index, converter::FmkTy | |||||
| DataInfo *data_info) { | DataInfo *data_info) { | ||||
| MS_ASSERT(cnode != nullptr && data_info != nullptr); | MS_ASSERT(cnode != nullptr && data_info != nullptr); | ||||
| auto value_node = cnode->input(index)->cast<ValueNodePtr>(); | auto value_node = cnode->input(index)->cast<ValueNodePtr>(); | ||||
| if (value_node == nullptr) { | |||||
| MS_LOG(ERROR) << "input node is not value node."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto value = value_node->value(); | auto value = value_node->value(); | ||||
| int ret = RET_OK; | int ret = RET_OK; | ||||
| auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | ||||
| @@ -18,26 +18,91 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "ops/depend.h" | |||||
| #include "ops/make_tuple.h" | #include "ops/make_tuple.h" | ||||
| namespace mindspore::opt { | namespace mindspore::opt { | ||||
| namespace { | namespace { | ||||
| constexpr size_t kInputDoubleNum = 2; | constexpr size_t kInputDoubleNum = 2; | ||||
| constexpr size_t kInputTripleNum = 3; | constexpr size_t kInputTripleNum = 3; | ||||
| void FetchCNodeFromMakeTuple(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *inputs) { | |||||
| MS_ASSERT(anf_node != nullptr); | |||||
| MS_ASSERT(inputs != nullptr); | |||||
| auto cnode = anf_node->cast<CNodePtr>(); | |||||
| if (cnode == nullptr) { | |||||
| return; | |||||
| int ProcessInputIsMonad(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||||
| auto first_input = cnode->input(1); | |||||
| auto second_input = cnode->input(2); | |||||
| AnfNodePtr must_monad = nullptr; | |||||
| AnfNodePtr not_must_monad = nullptr; | |||||
| if (utils::isa<ValueNode>(first_input)) { | |||||
| auto value_node = first_input->cast<ValueNodePtr>(); | |||||
| MS_ASSERT(value_node->value() != nullptr); | |||||
| if (utils::isa<Monad>(value_node->value())) { | |||||
| must_monad = first_input; | |||||
| not_must_monad = second_input; | |||||
| } | |||||
| } | } | ||||
| for (size_t i = 1; i < cnode->size(); ++i) { | |||||
| if (cnode->input(i)->isa<CNode>()) { | |||||
| inputs->push_back(cnode->input(i)); | |||||
| if (utils::isa<ValueNode>(second_input)) { | |||||
| auto value_node = second_input->cast<ValueNodePtr>(); | |||||
| MS_ASSERT(value_node->value() != nullptr); | |||||
| if (utils::isa<Monad>(value_node->value())) { | |||||
| must_monad = second_input; | |||||
| not_must_monad = first_input; | |||||
| } | } | ||||
| } | } | ||||
| if (must_monad == nullptr) { | |||||
| return lite::RET_NO_CHANGE; | |||||
| } | |||||
| auto manager = func_graph->manager(); | |||||
| MS_ASSERT(manager != nullptr); | |||||
| if (!utils::isa<CNode>(not_must_monad) || CheckIsAllInputsParam(not_must_monad)) { | |||||
| manager->Replace(cnode, must_monad); | |||||
| } else { | |||||
| manager->Replace(cnode, not_must_monad); | |||||
| } | |||||
| return lite::RET_OK; | |||||
| } | |||||
| int ProcessDependencyWithTwoNodes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, bool pre_node_is_first) { | |||||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||||
| AnfNodePtr pre_node = cnode->input(1); | |||||
| AnfNodePtr post_node = cnode->input(2); | |||||
| if (!pre_node_is_first) { | |||||
| pre_node = cnode->input(2); | |||||
| post_node = cnode->input(1); | |||||
| } | |||||
| auto manager = func_graph->manager(); | |||||
| MS_ASSERT(manager != nullptr); | |||||
| auto node_users = manager->node_users()[pre_node]; | |||||
| auto iter = | |||||
| std::find_if(node_users.begin(), node_users.end(), | |||||
| [&post_node](const std::pair<AnfNodePtr, int> &post_pair) { return post_pair.first == post_node; }); | |||||
| if (iter == node_users.end()) { | |||||
| return lite::RET_NO_CHANGE; | |||||
| } | |||||
| auto tr = manager->Transact(); | |||||
| tr.SetEdge(post_node, iter->second, NewValueNode(std::make_shared<UMonad>())); | |||||
| tr.Commit(); | |||||
| auto depend_prim = std::make_shared<ops::Depend>(); | |||||
| auto depend_node = func_graph->NewCNode(depend_prim, {post_node, pre_node}); | |||||
| depend_node->set_fullname_with_scope(cnode->fullname_with_scope()); | |||||
| manager->Replace(cnode, depend_node); | |||||
| return lite::RET_OK; | |||||
| } | |||||
| int ProcessInputHaveDependency(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||||
| if (ProcessDependencyWithTwoNodes(func_graph, cnode, true) == lite::RET_OK) { | |||||
| return lite::RET_OK; | |||||
| } | |||||
| if (ProcessDependencyWithTwoNodes(func_graph, cnode, false) == lite::RET_OK) { | |||||
| return lite::RET_OK; | |||||
| } | |||||
| auto make_tuple_prim = NewValueNode(std::make_shared<ops::MakeTuple>()); | |||||
| auto manager = func_graph->manager(); | |||||
| MS_ASSERT(manager != nullptr); | |||||
| manager->Replace(cnode->input(0), make_tuple_prim); | |||||
| return lite::RET_OK; | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) { | int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) { | ||||
| if (!utils::isa<CNodePtr>(anf_node)) { | if (!utils::isa<CNodePtr>(anf_node)) { | ||||
| MS_LOG(DEBUG) << "anf node is node a cnode."; | MS_LOG(DEBUG) << "anf node is node a cnode."; | ||||
| @@ -73,28 +138,11 @@ int RemoveRedundantOpPass::ReplaceUpdateStateOp(const FuncGraphPtr &func_graph, | |||||
| return lite::RET_NO_CHANGE; | return lite::RET_NO_CHANGE; | ||||
| } | } | ||||
| auto cnode = anf_node->cast<CNodePtr>(); | auto cnode = anf_node->cast<CNodePtr>(); | ||||
| auto inputs = cnode->inputs(); | |||||
| std::vector<AnfNodePtr> new_inputs; | |||||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||||
| if (!inputs[i]->isa<CNode>()) { | |||||
| continue; | |||||
| } | |||||
| if (CheckPrimitiveType(inputs[i], prim::kPrimMakeTuple)) { | |||||
| FetchCNodeFromMakeTuple(inputs[i], &new_inputs); | |||||
| continue; | |||||
| } | |||||
| new_inputs.push_back(inputs[i]); | |||||
| if (ProcessInputIsMonad(func_graph, cnode) == lite::RET_OK) { | |||||
| return lite::RET_OK; | |||||
| } | } | ||||
| for (auto &node : new_inputs) { | |||||
| func_graph->get_return()->add_input(node); | |||||
| } | |||||
| auto value = std::make_shared<UMonad>(); | |||||
| bool replace_succ = func_graph->manager()->Replace(anf_node, NewValueNode(value)); | |||||
| if (!replace_succ) { | |||||
| MS_LOG(ERROR) << "replace redundant op failed."; | |||||
| return lite::RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| // both of two inputs are not monad, but have dependency. | |||||
| return ProcessInputHaveDependency(func_graph, cnode); | |||||
| } | } | ||||
| int RemoveRedundantOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) { | int RemoveRedundantOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) { | ||||