Merge pull request !2791 from changzherui/subgraphtags/v0.6.0-beta
| @@ -28,6 +28,7 @@ | |||||
| #include "utils/config_manager.h" | #include "utils/config_manager.h" | ||||
| #include "utils/convert_utils.h" | #include "utils/convert_utils.h" | ||||
| #include "./common.h" | #include "./common.h" | ||||
| #include "utils/context/ms_context.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace transform { | namespace transform { | ||||
| @@ -206,6 +207,7 @@ const char kNameRange[] = "Range"; | |||||
| const char kNameSquareSumAll[] = "SquareSumAll"; | const char kNameSquareSumAll[] = "SquareSumAll"; | ||||
| const char kNameAscendQuant[] = "AscendQuant"; | const char kNameAscendQuant[] = "AscendQuant"; | ||||
| const char kNameAscendDequant[] = "AscendDequant"; | const char kNameAscendDequant[] = "AscendDequant"; | ||||
| const char kNameCase[] = "Case"; | |||||
| // -----------------OpAdapter initialization-------------- | // -----------------OpAdapter initialization-------------- | ||||
| std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_map() { | std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_map() { | ||||
| @@ -413,7 +415,8 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma | |||||
| {string(kNameRange), ADPT_DESC(RangeD)}, | {string(kNameRange), ADPT_DESC(RangeD)}, | ||||
| {string(kNameSquareSumAll), ADPT_DESC(SquareSumAll)}, | {string(kNameSquareSumAll), ADPT_DESC(SquareSumAll)}, | ||||
| {string(kNameAscendQuant), ADPT_DESC(AscendQuant)}, | {string(kNameAscendQuant), ADPT_DESC(AscendQuant)}, | ||||
| {string(kNameAscendDequant), ADPT_DESC(AscendDequant)}}; | |||||
| {string(kNameAscendDequant), ADPT_DESC(AscendDequant)}, | |||||
| {string(kNameCase), ADPT_DESC(Case)}}; | |||||
| #ifdef ENABLE_GE | #ifdef ENABLE_GE | ||||
| adpt_map[string(kNamePrint)] = ADPT_DESC(Print); | adpt_map[string(kNamePrint)] = ADPT_DESC(Print); | ||||
| adpt_map[string(kNameApplyAdam)] = ADPT_DESC(ApplyAdamD); | adpt_map[string(kNameApplyAdam)] = ADPT_DESC(ApplyAdamD); | ||||
| @@ -435,13 +438,32 @@ PrimType GetCNodeFuncType(const CNodePtr cnode) { | |||||
| return kPrimTypeUnknown; | return kPrimTypeUnknown; | ||||
| } | } | ||||
| bool IsCaseNode(const CNodePtr node) { | |||||
| if (!node->inputs().empty() && node->input(0)->isa<CNode>() && | |||||
| GetCNodeFuncName(node->input(0)->cast<CNodePtr>()) == "switch_layer") { | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| std::string GetCNodeTargetFuncName(const CNodePtr cnode) { | |||||
| if (IsCaseNode(cnode)) { | |||||
| return string(kNameCase); | |||||
| } | |||||
| auto name = GetCNodeFuncName(cnode); | |||||
| if (name == "switch_layer") { | |||||
| name = ""; | |||||
| } | |||||
| return name; | |||||
| } | |||||
| OpAdapterPtr DfGraphConvertor::FindAdapter(const AnfNodePtr node, bool train) { | OpAdapterPtr DfGraphConvertor::FindAdapter(const AnfNodePtr node, bool train) { | ||||
| if (node->isa<CNode>()) { | if (node->isa<CNode>()) { | ||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| std::string name = kNameCustomOp; | std::string name = kNameCustomOp; | ||||
| if (!IsCustomCNode(cnode)) { | if (!IsCustomCNode(cnode)) { | ||||
| name = GetCNodeFuncName(cnode); | |||||
| name = GetCNodeTargetFuncName(cnode); | |||||
| } | } | ||||
| auto it_adpt = get_adpt_map().find(name); | auto it_adpt = get_adpt_map().find(name); | ||||
| @@ -959,7 +981,7 @@ void DfGraphConvertor::TraceOutput(const AnfNodePtr node) { | |||||
| auto c = anf_out->cast<CNodePtr>(); | auto c = anf_out->cast<CNodePtr>(); | ||||
| std::string name = ""; | std::string name = ""; | ||||
| if (anf_out->isa<CNode>()) { | if (anf_out->isa<CNode>()) { | ||||
| name = GetCNodeFuncName(c); | |||||
| name = GetCNodeTargetFuncName(c); | |||||
| } | } | ||||
| if (name == "make_tuple") { | if (name == "make_tuple") { | ||||
| @@ -1031,6 +1053,99 @@ void SetupDatasetIterGetNextNode(const OperatorPtr &op) { | |||||
| return; | return; | ||||
| } | } | ||||
| void DfGraphConvertor::SetSubgraph(AnfNodePtr node) { | |||||
| if (!node->isa<CNode>()) { | |||||
| return; | |||||
| } | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| if (!IsCaseNode(cnode)) { | |||||
| return; | |||||
| } | |||||
| std::vector<AnfNodePtr> case_inputs; | |||||
| for (size_t i = 1; i < cnode->inputs().size(); i++) { | |||||
| case_inputs.emplace_back(cnode->input(i)); | |||||
| } | |||||
| std::shared_ptr<std::vector<DfGraph>> branches = std::make_shared<std::vector<DfGraph>>(); | |||||
| auto bnode = cnode->input(0)->cast<CNodePtr>()->input(2)->cast<CNodePtr>(); | |||||
| for (size_t i = 1; i < bnode->inputs().size(); i++) { | |||||
| auto branch_node = bnode->input(i)->cast<CNodePtr>(); | |||||
| for (size_t j = 2; j < branch_node->inputs().size(); j++) { | |||||
| if (std::find(case_inputs.begin(), case_inputs.end(), branch_node->input(j)) == case_inputs.end()) { | |||||
| case_inputs.emplace_back(branch_node->input(j)); | |||||
| } | |||||
| } | |||||
| } | |||||
| for (size_t i = 1; i < bnode->inputs().size(); i++) { | |||||
| ProcessSubgraph(bnode->input(i), case_inputs); | |||||
| } | |||||
| for (size_t i = 1; i < bnode->inputs().size(); i++) { | |||||
| branches->emplace_back(branches_map_[bnode->input(i).get()]); | |||||
| } | |||||
| if (op_cache_.find(node.get()) == op_cache_.end()) { | |||||
| return; | |||||
| } | |||||
| OpAdapterPtr adpt = FindAdapter(node, training_); | |||||
| if (nullptr == adpt) { | |||||
| MS_LOG(DEBUG) << "Not found adapter"; | |||||
| return; | |||||
| } | |||||
| OperatorPtr op = Convert(node); | |||||
| adpt->setSubgraph(op, 0, branches); | |||||
| return; | |||||
| } | |||||
| void DfGraphConvertor::GetCaseNodeInput(const CNodePtr node, const CNodePtr input_node) { | |||||
| std::vector<AnfNodePtr> case_inputs; | |||||
| for (size_t i = 1; i < node->inputs().size(); i++) { | |||||
| case_inputs.emplace_back(node->input(i)); | |||||
| } | |||||
| std::shared_ptr<std::vector<DfGraph>> branches = std::make_shared<std::vector<DfGraph>>(); | |||||
| auto bnode = input_node->input(2)->cast<CNodePtr>(); | |||||
| for (size_t i = 1; i < bnode->inputs().size(); i++) { | |||||
| auto branch_node = bnode->input(i)->cast<CNodePtr>(); | |||||
| for (size_t j = 2; j < branch_node->inputs().size(); j++) { | |||||
| if (std::find(case_inputs.begin(), case_inputs.end(), branch_node->input(j)) == case_inputs.end()) { | |||||
| case_inputs.emplace_back(branch_node->input(j)); | |||||
| } | |||||
| } | |||||
| } | |||||
| const size_t case_index = 1; | |||||
| const size_t make_tuple_index = 2; | |||||
| AnfNodePtr case_index_iter = input_node->input(case_index); | |||||
| AnfNodePtr make_tuple_iter = input_node->input(make_tuple_index); | |||||
| auto make_tuple_node = make_tuple_iter->cast<CNodePtr>(); | |||||
| std::shared_ptr<std::vector<OutHandler>> tuple_items = std::make_shared<std::vector<OutHandler>>(); | |||||
| for (size_t i = 0; i < case_inputs.size(); i++) { | |||||
| auto item = case_inputs[i]; | |||||
| auto op = Convert(item); | |||||
| if (op != nullptr) { | |||||
| tuple_items->emplace_back(OutHandler(op, "")); | |||||
| } else if (out_handle_cache_.find(item.get()) != out_handle_cache_.end()) { | |||||
| tuple_items->push_back(out_handle_cache_[item.get()]); | |||||
| } else { | |||||
| MS_LOG(WARNING) << "This anf node is not supported as a case input: " << item->ToString(); | |||||
| continue; | |||||
| } | |||||
| } | |||||
| tuple_out_handle_cache_[make_tuple_node.get()] = tuple_items; | |||||
| std::shared_ptr<std::vector<AnfNodePtr>> case_input_items = std::make_shared<std::vector<AnfNodePtr>>(); | |||||
| case_input_items->emplace_back(case_index_iter); | |||||
| case_input_items->emplace_back(make_tuple_iter); | |||||
| case_input_handle_cache_[node.get()] = case_input_items; | |||||
| } | |||||
| DfGraphConvertor &DfGraphConvertor::BuildGraph() { | DfGraphConvertor &DfGraphConvertor::BuildGraph() { | ||||
| SetupDatasetIterGetNextNode(dataset_iter_getnext_); | SetupDatasetIterGetNextNode(dataset_iter_getnext_); | ||||
| @@ -1038,6 +1153,16 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() { | |||||
| return *this; | return *this; | ||||
| } | } | ||||
| // Case node set input. | |||||
| std::vector<AnfNodePtr> nodes = ::mindspore::TopoSort(anf_graph_->get_return()); | |||||
| for (auto &it : nodes) { | |||||
| if (it->isa<CNode>() && IsCaseNode(it->cast<CNodePtr>())) { | |||||
| auto node = it->cast<CNodePtr>(); | |||||
| auto input_node = node->input(0)->cast<CNodePtr>(); | |||||
| GetCaseNodeInput(node, input_node); | |||||
| } | |||||
| } | |||||
| // update tuple_out_handle_cache_ | // update tuple_out_handle_cache_ | ||||
| for (auto it : tuple_out_handle_cache_) { | for (auto it : tuple_out_handle_cache_) { | ||||
| std::size_t len = it.second->size(); | std::size_t len = it.second->size(); | ||||
| @@ -1058,10 +1183,11 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() { | |||||
| // set up dependices | // set up dependices | ||||
| MS_LOG(DEBUG) << "set up dependices"; | MS_LOG(DEBUG) << "set up dependices"; | ||||
| std::vector<AnfNodePtr> nodes = ::mindspore::TopoSort(anf_graph_->get_return()); | |||||
| nodes = ::mindspore::TopoSort(anf_graph_->get_return()); | |||||
| for (auto &it : nodes) { | for (auto &it : nodes) { | ||||
| SetNodeInput(it); | SetNodeInput(it); | ||||
| SetOpControlInput(it); | SetOpControlInput(it); | ||||
| SetSubgraph(it); | |||||
| UpdateOpDesc(it); | UpdateOpDesc(it); | ||||
| } | } | ||||
| @@ -1077,6 +1203,18 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() { | |||||
| inputs.push_back(*dataset_iter_getnext_); | inputs.push_back(*dataset_iter_getnext_); | ||||
| } else { | } else { | ||||
| auto params = anf_graph_->parameters(); | auto params = anf_graph_->parameters(); | ||||
| if (use_inputs_) { | |||||
| params = inputs_; | |||||
| auto anf_params = anf_graph_->parameters(); | |||||
| for (size_t i = 0; i < params.size(); i++) { | |||||
| for (size_t j = 0; j < anf_params.size(); j++) { | |||||
| if (params[i]->ToString() == anf_params[j]->ToString()) { | |||||
| params[i] = anf_params[j]; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| int index = 0; | int index = 0; | ||||
| for (auto &it : params) { | for (auto &it : params) { | ||||
| auto name = std::static_pointer_cast<Parameter>(it)->name(); | auto name = std::static_pointer_cast<Parameter>(it)->name(); | ||||
| @@ -1187,10 +1325,21 @@ const std::vector<std::string> trans_var_list = {string(kNameAssign), string(kNa | |||||
| void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node) { | void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node) { | ||||
| OperatorPtr src = Convert(node); | OperatorPtr src = Convert(node); | ||||
| int case_flag = 0; | |||||
| auto &inputs = node->inputs(); | auto &inputs = node->inputs(); | ||||
| for (size_t i = 1; i < inputs.size(); i++) { | |||||
| size_t input_size = inputs.size(); | |||||
| if (case_input_handle_cache_.find(node.get()) != case_input_handle_cache_.end()) { | |||||
| case_flag = 1; | |||||
| input_size = case_input_handle_cache_[node.get()]->size() + 1; | |||||
| } | |||||
| for (size_t i = 1; i < input_size; i++) { | |||||
| auto pred = inputs[i]; | auto pred = inputs[i]; | ||||
| while (pred->isa<CNode>() && GetCNodeFuncName(pred->cast<CNodePtr>()) == "Depend") { | |||||
| if (case_flag != 0) { | |||||
| pred = case_input_handle_cache_[node.get()]->at(i - 1); | |||||
| } | |||||
| while (pred->isa<CNode>() && GetCNodeTargetFuncName(pred->cast<CNodePtr>()) == "Depend") { | |||||
| pred = pred->cast<CNodePtr>()->input(1); | pred = pred->cast<CNodePtr>()->input(1); | ||||
| } | } | ||||
| // skip the None input | // skip the None input | ||||
| @@ -1198,7 +1347,7 @@ void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node | |||||
| continue; | continue; | ||||
| } | } | ||||
| // transform "Const" op to "Variable" op when the next node is "Assign" op. | // transform "Const" op to "Variable" op when the next node is "Assign" op. | ||||
| std::string c_name = GetCNodeFuncName(node); | |||||
| std::string c_name = GetCNodeTargetFuncName(node); | |||||
| auto pos = std::find(trans_var_list.begin(), trans_var_list.end(), c_name); | auto pos = std::find(trans_var_list.begin(), trans_var_list.end(), c_name); | ||||
| if (!training_ && pos != trans_var_list.end() && pred->isa<Parameter>()) { | if (!training_ && pos != trans_var_list.end() && pred->isa<Parameter>()) { | ||||
| std::string name = std::static_pointer_cast<Parameter>(pred)->name(); | std::string name = std::static_pointer_cast<Parameter>(pred)->name(); | ||||
| @@ -1222,7 +1371,7 @@ void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node | |||||
| if (it != out_handle_cache_.end()) { | if (it != out_handle_cache_.end()) { | ||||
| int ret = adpt->setInput(src, SizeToInt(i), it->second); | int ret = adpt->setInput(src, SizeToInt(i), it->second); | ||||
| if (ret == 0) { | if (ret == 0) { | ||||
| if (pred->isa<CNode>() && GetCNodeFuncName(pred->cast<CNodePtr>()) == "tuple_getitem") { | |||||
| if (pred->isa<CNode>() && GetCNodeTargetFuncName(pred->cast<CNodePtr>()) == "tuple_getitem") { | |||||
| compute_sout_ << op_draw_name_[pred->cast<CNodePtr>()->input(1).get()] << " -> " << op_draw_name_[node.get()] | compute_sout_ << op_draw_name_[pred->cast<CNodePtr>()->input(1).get()] << " -> " << op_draw_name_[node.get()] | ||||
| << ":" << i << endl; | << ":" << i << endl; | ||||
| } else if (pred->isa<Parameter>()) { | } else if (pred->isa<Parameter>()) { | ||||
| @@ -1280,6 +1429,23 @@ void DfGraphConvertor::SetNodeInput(const AnfNodePtr node) { | |||||
| DfGraphConvertor::SetOpInput(adpt, cnode); | DfGraphConvertor::SetOpInput(adpt, cnode); | ||||
| } | } | ||||
| void DfGraphConvertor::ProcessSubgraph(AnfNodePtr node, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (!node->isa<CNode>() || GetCNodeFuncName(node->cast<CNodePtr>()) != "Partial") { | |||||
| return; | |||||
| } | |||||
| auto graph_node = node->cast<CNodePtr>()->input(1)->cast<ValueNodePtr>(); | |||||
| FuncGraphPtr anf_graph = graph_node->value()->cast<FuncGraphPtr>(); | |||||
| DfGraphConvertor convertor(anf_graph); | |||||
| convertor.use_inputs_ = true; | |||||
| convertor.inputs_ = inputs; | |||||
| (void)convertor.ConvertAllNode().BuildGraph(); | |||||
| std::string name = graph_node->ToString() + "_ge_graph.dot"; | |||||
| if (MsContext::GetInstance()->save_graphs_flag()) { | |||||
| convertor.DrawComputeGraph(name); | |||||
| } | |||||
| branches_map_[node.get()] = *(convertor.df_graph_); | |||||
| } | |||||
| // Update GE op's shape and type info | // Update GE op's shape and type info | ||||
| void DfGraphConvertor::UpdateOpDesc(const AnfNodePtr node) { | void DfGraphConvertor::UpdateOpDesc(const AnfNodePtr node) { | ||||
| if (nullptr == node || !node->isa<CNode>()) { | if (nullptr == node || !node->isa<CNode>()) { | ||||
| @@ -1350,6 +1516,7 @@ void DfGraphConvertor::ConvertMakeTuple(const CNodePtr node) { | |||||
| } | } | ||||
| } | } | ||||
| MS_LOG(WARNING) << "ConvertMakeTuple: " << node.get() << " " << tuple_items->size(); | |||||
| tuple_out_handle_cache_[node.get()] = tuple_items; | tuple_out_handle_cache_[node.get()] = tuple_items; | ||||
| } | } | ||||
| @@ -1713,6 +1880,14 @@ bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node) | |||||
| return false; | return false; | ||||
| } | } | ||||
| if (name == "" && GetCNodeFuncName(node) == "switch_layer") { | |||||
| return false; | |||||
| } | |||||
| if (name == "Partial") { | |||||
| return false; | |||||
| } | |||||
| // make_tuple is used for a dynamic_input, convert it to a vector of OutHandlers | // make_tuple is used for a dynamic_input, convert it to a vector of OutHandlers | ||||
| if (name == "make_tuple") { | if (name == "make_tuple") { | ||||
| ConvertMakeTuple(node); | ConvertMakeTuple(node); | ||||
| @@ -1734,7 +1909,7 @@ bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node) | |||||
| } | } | ||||
| OperatorPtr DfGraphConvertor::ConvertCNode(const CNodePtr node) { | OperatorPtr DfGraphConvertor::ConvertCNode(const CNodePtr node) { | ||||
| std::string name = GetCNodeFuncName(node); | |||||
| std::string name = GetCNodeTargetFuncName(node); | |||||
| if (!CheckCNode(name, node)) { | if (!CheckCNode(name, node)) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -1881,7 +2056,7 @@ void DfGraphConvertor::DrawCNode(const CNodePtr node, const OpAdapterPtr adpt) { | |||||
| } | } | ||||
| compute_sout_ << "<tr><td colspan=\"" << (input_map.size() + dyn_input_map.size()) << "\">\"" << node->ToString() | compute_sout_ << "<tr><td colspan=\"" << (input_map.size() + dyn_input_map.size()) << "\">\"" << node->ToString() | ||||
| << ":" << GetCNodeFuncName(node) << "\"</td></tr>" << endl; | |||||
| << ":" << GetCNodeTargetFuncName(node) << "\"</td></tr>" << endl; | |||||
| // print attrs' values | // print attrs' values | ||||
| auto atts = adpt->GetAttrsFromDrawGraph(); | auto atts = adpt->GetAttrsFromDrawGraph(); | ||||
| @@ -201,6 +201,7 @@ class DfGraphConvertor { | |||||
| OperatorPtr ConvertParameter(AnfNodePtr node); | OperatorPtr ConvertParameter(AnfNodePtr node); | ||||
| Status TryConvertValueNodeToMultiConst(const ValueNodePtr node); | Status TryConvertValueNodeToMultiConst(const ValueNodePtr node); | ||||
| OperatorPtr ConvertValueNode(ValueNodePtr node); | OperatorPtr ConvertValueNode(ValueNodePtr node); | ||||
| void GetCaseNodeInput(const CNodePtr node, const CNodePtr input_node); | |||||
| void ConvertTupleGetItem(const CNodePtr node); | void ConvertTupleGetItem(const CNodePtr node); | ||||
| void GetDependOnParameterUse(const CNodePtr &node, const AnfNodePtr &src_node, const AnfNodePtr &dest_node, | void GetDependOnParameterUse(const CNodePtr &node, const AnfNodePtr &src_node, const AnfNodePtr &dest_node, | ||||
| const std::shared_ptr<std::vector<OperatorPtr>> &src_ops_list, | const std::shared_ptr<std::vector<OperatorPtr>> &src_ops_list, | ||||
| @@ -217,6 +218,8 @@ class DfGraphConvertor { | |||||
| void SetNodeInput(AnfNodePtr node); | void SetNodeInput(AnfNodePtr node); | ||||
| void SetOpControlInput(const AnfNodePtr node); | void SetOpControlInput(const AnfNodePtr node); | ||||
| void UpdateOpDesc(AnfNodePtr node); | void UpdateOpDesc(AnfNodePtr node); | ||||
| void SetSubgraph(AnfNodePtr node); | |||||
| void ProcessSubgraph(AnfNodePtr node, const std::vector<AnfNodePtr> &inputs); | |||||
| void BuildSaveCheckpointGraph(); | void BuildSaveCheckpointGraph(); | ||||
| void DrawCNode(const CNodePtr node, const OpAdapterPtr adpt); | void DrawCNode(const CNodePtr node, const OpAdapterPtr adpt); | ||||
| void UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const; | void UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const; | ||||
| @@ -228,22 +231,26 @@ class DfGraphConvertor { | |||||
| std::shared_ptr<DfGraph> save_ckp_graph_{nullptr}; | std::shared_ptr<DfGraph> save_ckp_graph_{nullptr}; | ||||
| std::shared_ptr<DfGraph> restore_ckp_graph_{nullptr}; | std::shared_ptr<DfGraph> restore_ckp_graph_{nullptr}; | ||||
| std::shared_ptr<DfGraph> broadcast_graph_{nullptr}; | std::shared_ptr<DfGraph> broadcast_graph_{nullptr}; | ||||
| std::unordered_map<AnfNode *, DfGraph> branches_map_; | |||||
| std::unordered_map<AnfNode *, OperatorPtr> op_cache_; | std::unordered_map<AnfNode *, OperatorPtr> op_cache_; | ||||
| std::unordered_map<AnfNode *, std::vector<ControlEdge>> control_depend_cache_; | std::unordered_map<AnfNode *, std::vector<ControlEdge>> control_depend_cache_; | ||||
| /* record "tuple_getitem"<->"out_handler" mapping */ | /* record "tuple_getitem"<->"out_handler" mapping */ | ||||
| std::unordered_map<AnfNode *, OutHandler> out_handle_cache_; | std::unordered_map<AnfNode *, OutHandler> out_handle_cache_; | ||||
| /* record "make_tuple"<->"out_handler vector" mapping */ | /* record "make_tuple"<->"out_handler vector" mapping */ | ||||
| std::unordered_map<AnfNode *, std::shared_ptr<std::vector<OutHandler>>> tuple_out_handle_cache_; | std::unordered_map<AnfNode *, std::shared_ptr<std::vector<OutHandler>>> tuple_out_handle_cache_; | ||||
| std::unordered_map<AnfNode *, std::shared_ptr<std::vector<AnfNodePtr>>> case_input_handle_cache_; | |||||
| std::unordered_map<std::string, AnfNodePtr> params_; | std::unordered_map<std::string, AnfNodePtr> params_; | ||||
| std::unordered_map<std::string, OperatorPtr> vars_; | std::unordered_map<std::string, OperatorPtr> vars_; | ||||
| std::vector<std::pair<ge::Operator, std::string>> graph_outputs_; | std::vector<std::pair<ge::Operator, std::string>> graph_outputs_; | ||||
| std::vector<OperatorPtr> graph_const_inputs_; | std::vector<OperatorPtr> graph_const_inputs_; | ||||
| std::vector<OperatorPtr> init_ops_; | std::vector<OperatorPtr> init_ops_; | ||||
| std::vector<OperatorPtr> broadcast_ops_; | std::vector<OperatorPtr> broadcast_ops_; | ||||
| std::vector<AnfNodePtr> inputs_; | |||||
| OperatorPtr dataset_iter_getnext_; | OperatorPtr dataset_iter_getnext_; | ||||
| Status error_ = SUCCESS; | Status error_ = SUCCESS; | ||||
| bool training_ = false; | bool training_ = false; | ||||
| bool distribute_ = false; | bool distribute_ = false; | ||||
| bool use_inputs_ = false; | |||||
| }; | }; | ||||
| } // namespace transform | } // namespace transform | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -164,6 +164,25 @@ class OpAdapter : public BaseOpAdapter { | |||||
| const std::unordered_map<unsigned int, AttrDesc> &getInputAttrMap() override { return input_attr_map_; } | const std::unordered_map<unsigned int, AttrDesc> &getInputAttrMap() override { return input_attr_map_; } | ||||
| const std::unordered_map<int, DynInputDesc> &getDynInputMap() override { return dyn_input_map_; } | const std::unordered_map<int, DynInputDesc> &getDynInputMap() override { return dyn_input_map_; } | ||||
| const std::unordered_map<int, OutputDesc> &getOutputMap() override { return output_map_; } | const std::unordered_map<int, OutputDesc> &getOutputMap() override { return output_map_; } | ||||
| const std::unordered_map<int, DynSubGraphDesc> &getDynSubgraphMap() override { return dyn_subgraph_map_; } | |||||
| Status SetOpSubgraphFunc(const OperatorPtr &op, int index, std::shared_ptr<std::vector<DfGraph>> branches) { | |||||
| MS_EXCEPTION_IF_NULL(op); | |||||
| auto it = dyn_subgraph_map_.find(index); | |||||
| if (it != dyn_subgraph_map_.end()) { | |||||
| auto size = branches->size(); | |||||
| it->second.create_dyn_subgraph(op, static_cast<unsigned int>(size)); | |||||
| for (size_t i = 0; i < size; i++) { | |||||
| it->second.set_subgraph(op, static_cast<unsigned int>(i), std::make_shared<DfGraph>((*branches)[i])); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| return NOT_FOUND; | |||||
| } | |||||
| int setSubgraph(const OperatorPtr &op, int index, std::shared_ptr<std::vector<DfGraph>> branches) override { | |||||
| return static_cast<int>(SetOpSubgraphFunc(op, index, branches)); | |||||
| } | |||||
| Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OperatorPtr &input) { | Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OperatorPtr &input) { | ||||
| MS_EXCEPTION_IF_NULL(op); | MS_EXCEPTION_IF_NULL(op); | ||||
| @@ -855,6 +874,7 @@ class OpAdapter : public BaseOpAdapter { | |||||
| static const std::unordered_map<int, DynInputDesc> dyn_input_map_; | static const std::unordered_map<int, DynInputDesc> dyn_input_map_; | ||||
| static const std::unordered_map<int, OutputDesc> output_map_; | static const std::unordered_map<int, OutputDesc> output_map_; | ||||
| static const std::unordered_map<int, DynOutputDesc> dyn_output_map_; | static const std::unordered_map<int, DynOutputDesc> dyn_output_map_; | ||||
| static const std::unordered_map<int, DynSubGraphDesc> dyn_subgraph_map_; | |||||
| static const std::unordered_map<std::string, AttrDesc> attr_map_; | static const std::unordered_map<std::string, AttrDesc> attr_map_; | ||||
| static const std::unordered_map<std::string, int> enum_map_; | static const std::unordered_map<std::string, int> enum_map_; | ||||
| // convert input from anf graph to Attr in Operators | // convert input from anf graph to Attr in Operators | ||||
| @@ -874,6 +894,8 @@ const std::unordered_map<int, OutputDesc> OpAdapter<T>::output_map_; | |||||
| template <typename T> | template <typename T> | ||||
| const std::unordered_map<int, DynOutputDesc> OpAdapter<T>::dyn_output_map_; | const std::unordered_map<int, DynOutputDesc> OpAdapter<T>::dyn_output_map_; | ||||
| template <typename T> | template <typename T> | ||||
| const std::unordered_map<int, DynSubGraphDesc> OpAdapter<T>::dyn_subgraph_map_; | |||||
| template <typename T> | |||||
| const std::unordered_map<std::string, AttrDesc> OpAdapter<T>::attr_map_; | const std::unordered_map<std::string, AttrDesc> OpAdapter<T>::attr_map_; | ||||
| template <typename T> | template <typename T> | ||||
| const std::unordered_map<std::string, int> OpAdapter<T>::enum_map_; | const std::unordered_map<std::string, int> OpAdapter<T>::enum_map_; | ||||
| @@ -88,6 +88,8 @@ using DynInputOpFunc = std::function<void(OperatorPtr, unsigned int, OperatorPtr | |||||
| using DynInputHandleFunc = std::function<void(OperatorPtr, unsigned int, OutHandler)>; | using DynInputHandleFunc = std::function<void(OperatorPtr, unsigned int, OutHandler)>; | ||||
| using UpdateOutputDescFunc = std::function<void(OperatorPtr, GeTensorDesc)>; | using UpdateOutputDescFunc = std::function<void(OperatorPtr, GeTensorDesc)>; | ||||
| using CreateDynOutputOpFunc = std::function<void(OperatorPtr, unsigned int)>; | using CreateDynOutputOpFunc = std::function<void(OperatorPtr, unsigned int)>; | ||||
| using CreateDynSubGraphFunc = std::function<void(OperatorPtr, unsigned int)>; | |||||
| using DynSubGraphFunc = std::function<void(OperatorPtr, unsigned int, DfGraphPtr)>; | |||||
| struct AttrDesc { | struct AttrDesc { | ||||
| std::string name; | std::string name; | ||||
| @@ -108,6 +110,12 @@ struct DynInputDesc { | |||||
| DynInputHandleFunc set_handle; | DynInputHandleFunc set_handle; | ||||
| }; | }; | ||||
| struct DynSubGraphDesc { | |||||
| std::string name; | |||||
| CreateDynSubGraphFunc create_dyn_subgraph; | |||||
| DynSubGraphFunc set_subgraph; | |||||
| }; | |||||
| struct OutputDesc { | struct OutputDesc { | ||||
| std::string name; | std::string name; | ||||
| UpdateOutputDescFunc update_out_desc; | UpdateOutputDescFunc update_out_desc; | ||||
| @@ -123,6 +131,7 @@ class BaseOpAdapter { | |||||
| virtual ~BaseOpAdapter() {} | virtual ~BaseOpAdapter() {} | ||||
| virtual OperatorPtr generate(const AnfNodePtr &anf) = 0; | virtual OperatorPtr generate(const AnfNodePtr &anf) = 0; | ||||
| virtual OperatorPtr generate(const std::string &type) { return std::make_shared<ge::Operator>(type); } | virtual OperatorPtr generate(const std::string &type) { return std::make_shared<ge::Operator>(type); } | ||||
| virtual int setSubgraph(const OperatorPtr &op, int index, std::shared_ptr<std::vector<DfGraph>> branches) = 0; | |||||
| virtual int setInput(const OperatorPtr &op, int index, const OperatorPtr &input) = 0; | virtual int setInput(const OperatorPtr &op, int index, const OperatorPtr &input) = 0; | ||||
| virtual int setInput(const OperatorPtr &op, int index, const OutHandler &handle) = 0; | virtual int setInput(const OperatorPtr &op, int index, const OutHandler &handle) = 0; | ||||
| virtual int setInput(const OperatorPtr &op, int index, | virtual int setInput(const OperatorPtr &op, int index, | ||||
| @@ -146,6 +155,7 @@ class BaseOpAdapter { | |||||
| virtual const std::unordered_map<unsigned int, AttrDesc> &getInputAttrMap() = 0; | virtual const std::unordered_map<unsigned int, AttrDesc> &getInputAttrMap() = 0; | ||||
| virtual const std::unordered_map<int, DynInputDesc> &getDynInputMap() = 0; | virtual const std::unordered_map<int, DynInputDesc> &getDynInputMap() = 0; | ||||
| virtual const std::unordered_map<int, OutputDesc> &getOutputMap() = 0; | virtual const std::unordered_map<int, OutputDesc> &getOutputMap() = 0; | ||||
| virtual const std::unordered_map<int, DynSubGraphDesc> &getDynSubgraphMap() = 0; | |||||
| void AddAttrToDrawGraph(const std::string &attr_str) { attrs_vec_.push_back(attr_str); } | void AddAttrToDrawGraph(const std::string &attr_str) { attrs_vec_.push_back(attr_str); } | ||||
| const std::vector<std::string> &GetAttrsFromDrawGraph() const { return attrs_vec_; } | const std::vector<std::string> &GetAttrsFromDrawGraph() const { return attrs_vec_; } | ||||
| void clearAttrVect() { attrs_vec_.clear(); } | void clearAttrVect() { attrs_vec_.clear(); } | ||||
| @@ -64,6 +64,22 @@ namespace transform { | |||||
| } \ | } \ | ||||
| } | } | ||||
| #define DYN_SUBGRAPH_MAP(T) \ | |||||
| template <> \ | |||||
| const std::unordered_map<int, DynSubGraphDesc> OpAdapter<T>::dyn_subgraph_map_ | |||||
| #define DYN_SUBGRAPH_DESC(name) \ | |||||
| { \ | |||||
| #name, \ | |||||
| [](const OperatorPtr op, unsigned int num) { \ | |||||
| auto p = std::static_pointer_cast<OpType>(op); \ | |||||
| (void)p->create_dynamic_subgraph_##name(num); \ | |||||
| }, \ | |||||
| [](const OperatorPtr op, unsigned int index, const DfGraphPtr graph) { \ | |||||
| auto p = std::static_pointer_cast<OpType>(op); \ | |||||
| (void)p->set_dynamic_subgraph_builder_##name(index, [graph](){return *graph;}); \ | |||||
| } \ | |||||
| } | |||||
| #define ATTR_MAP(T) \ | #define ATTR_MAP(T) \ | ||||
| template <> \ | template <> \ | ||||
| const std::unordered_map<std::string, AttrDesc> OpAdapter<T>::attr_map_ | const std::unordered_map<std::string, AttrDesc> OpAdapter<T>::attr_map_ | ||||
| @@ -848,6 +864,13 @@ INPUT_ATTR_MAP(Cast) = {{2, ATTR_DESC(dst_type, AnyTraits<GEType>())}}; | |||||
| ATTR_MAP(Cast) = EMPTY_ATTR_MAP; | ATTR_MAP(Cast) = EMPTY_ATTR_MAP; | ||||
| OUTPUT_MAP(Cast) = {{0, OUTPUT_DESC(y)}}; | OUTPUT_MAP(Cast) = {{0, OUTPUT_DESC(y)}}; | ||||
| // Case | |||||
| INPUT_MAP(Case) = {{1, INPUT_DESC(branch_index)}}; | |||||
| DYN_INPUT_MAP(Case) = {{2, DYN_INPUT_DESC(input)}}; | |||||
| ATTR_MAP(Case) = EMPTY_ATTR_MAP; | |||||
| DYN_OUTPUT_MAP(Case) = {{0, DYN_OUTPUT_DESC(output)}}; | |||||
| DYN_SUBGRAPH_MAP(Case) = {{0, DYN_SUBGRAPH_DESC(branches)}}; | |||||
| // Reciprocal | // Reciprocal | ||||
| INPUT_MAP(Reciprocal) = {{1, INPUT_DESC(x)}}; | INPUT_MAP(Reciprocal) = {{1, INPUT_DESC(x)}}; | ||||
| ATTR_MAP(Reciprocal) = EMPTY_ATTR_MAP; | ATTR_MAP(Reciprocal) = EMPTY_ATTR_MAP; | ||||
| @@ -46,6 +46,10 @@ namespace transform { | |||||
| template <> \ | template <> \ | ||||
| const std::unordered_map<int, DynInputDesc> OpAdapter<T>::dyn_input_map_; | const std::unordered_map<int, DynInputDesc> OpAdapter<T>::dyn_input_map_; | ||||
| #define DECLARE_OP_USE_DYN_SUBGRAPH(T) \ | |||||
| template <> \ | |||||
| const std::unordered_map<int, DynSubGraphDesc> OpAdapter<T>::dyn_subgraph_map_; | |||||
| #define DECLARE_OP_USE_DYN_OUTPUT(T) \ | #define DECLARE_OP_USE_DYN_OUTPUT(T) \ | ||||
| template <> \ | template <> \ | ||||
| const std::unordered_map<int, DynOutputDesc> OpAdapter<T>::dyn_output_map_; | const std::unordered_map<int, DynOutputDesc> OpAdapter<T>::dyn_output_map_; | ||||
| @@ -235,6 +239,10 @@ DECLARE_OP_USE_OUTPUT(RealDiv) | |||||
| DECLARE_OP_ADAPTER(Cast) | DECLARE_OP_ADAPTER(Cast) | ||||
| DECLARE_OP_USE_INPUT_ATTR(Cast) | DECLARE_OP_USE_INPUT_ATTR(Cast) | ||||
| DECLARE_OP_USE_OUTPUT(Cast) | DECLARE_OP_USE_OUTPUT(Cast) | ||||
| DECLARE_OP_ADAPTER(Case) | |||||
| DECLARE_OP_USE_DYN_INPUT(Case) | |||||
| DECLARE_OP_USE_DYN_SUBGRAPH(Case) | |||||
| DECLARE_OP_USE_DYN_OUTPUT(Case) | |||||
| DECLARE_OP_ADAPTER(Reciprocal) | DECLARE_OP_ADAPTER(Reciprocal) | ||||
| DECLARE_OP_USE_OUTPUT(Reciprocal) | DECLARE_OP_USE_OUTPUT(Reciprocal) | ||||
| DECLARE_OP_ADAPTER(Neg) | DECLARE_OP_ADAPTER(Neg) | ||||
| @@ -0,0 +1,41 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Test case.""" | |||||
| import numpy as np | |||||
| import mindspore | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor, context | |||||
| class Net(nn.Cell): | |||||
| def __init__(self): | |||||
| super(Net, self).__init__() | |||||
| self.conv1 = nn.Conv2d(1, 3, 3) | |||||
| self.conv2 = nn.Conv2d(1, 3, 5, has_bias=True) | |||||
| self.layers = (self.conv1, self.conv2) | |||||
| def construct(self, x, index): | |||||
| x = self.layers[index](x) | |||||
| y = self.conv1(x) | |||||
| return x + y | |||||
| def test_case(): | |||||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||||
| net = Net() | |||||
| data = Tensor(np.ones((1, 1, 224, 224)), mindspore.float32) | |||||
| idx = Tensor(1, mindspore.int32) | |||||
| net(data, idx) | |||||