|
|
|
@@ -28,6 +28,7 @@ |
|
|
|
#include "utils/config_manager.h" |
|
|
|
#include "utils/convert_utils.h" |
|
|
|
#include "./common.h" |
|
|
|
#include "utils/context/ms_context.h" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace transform { |
|
|
|
@@ -206,6 +207,7 @@ const char kNameRange[] = "Range"; |
|
|
|
const char kNameSquareSumAll[] = "SquareSumAll"; |
|
|
|
const char kNameAscendQuant[] = "AscendQuant"; |
|
|
|
const char kNameAscendDequant[] = "AscendDequant"; |
|
|
|
const char kNameCase[] = "Case"; |
|
|
|
|
|
|
|
// -----------------OpAdapter initialization-------------- |
|
|
|
std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_map() { |
|
|
|
@@ -413,7 +415,8 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma |
|
|
|
{string(kNameRange), ADPT_DESC(RangeD)}, |
|
|
|
{string(kNameSquareSumAll), ADPT_DESC(SquareSumAll)}, |
|
|
|
{string(kNameAscendQuant), ADPT_DESC(AscendQuant)}, |
|
|
|
{string(kNameAscendDequant), ADPT_DESC(AscendDequant)}}; |
|
|
|
{string(kNameAscendDequant), ADPT_DESC(AscendDequant)}, |
|
|
|
{string(kNameCase), ADPT_DESC(Case)}}; |
|
|
|
#ifdef ENABLE_GE |
|
|
|
adpt_map[string(kNamePrint)] = ADPT_DESC(Print); |
|
|
|
adpt_map[string(kNameApplyAdam)] = ADPT_DESC(ApplyAdamD); |
|
|
|
@@ -435,13 +438,32 @@ PrimType GetCNodeFuncType(const CNodePtr cnode) { |
|
|
|
return kPrimTypeUnknown; |
|
|
|
} |
|
|
|
|
|
|
|
bool IsCaseNode(const CNodePtr node) { |
|
|
|
if (!node->inputs().empty() && node->input(0)->isa<CNode>() && |
|
|
|
GetCNodeFuncName(node->input(0)->cast<CNodePtr>()) == "switch_layer") { |
|
|
|
return true; |
|
|
|
} |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
std::string GetCNodeTargetFuncName(const CNodePtr cnode) { |
|
|
|
if (IsCaseNode(cnode)) { |
|
|
|
return string(kNameCase); |
|
|
|
} |
|
|
|
auto name = GetCNodeFuncName(cnode); |
|
|
|
if (name == "switch_layer") { |
|
|
|
name = ""; |
|
|
|
} |
|
|
|
return name; |
|
|
|
} |
|
|
|
|
|
|
|
OpAdapterPtr DfGraphConvertor::FindAdapter(const AnfNodePtr node, bool train) { |
|
|
|
if (node->isa<CNode>()) { |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
|
|
|
|
std::string name = kNameCustomOp; |
|
|
|
if (!IsCustomCNode(cnode)) { |
|
|
|
name = GetCNodeFuncName(cnode); |
|
|
|
name = GetCNodeTargetFuncName(cnode); |
|
|
|
} |
|
|
|
|
|
|
|
auto it_adpt = get_adpt_map().find(name); |
|
|
|
@@ -959,7 +981,7 @@ void DfGraphConvertor::TraceOutput(const AnfNodePtr node) { |
|
|
|
auto c = anf_out->cast<CNodePtr>(); |
|
|
|
std::string name = ""; |
|
|
|
if (anf_out->isa<CNode>()) { |
|
|
|
name = GetCNodeFuncName(c); |
|
|
|
name = GetCNodeTargetFuncName(c); |
|
|
|
} |
|
|
|
|
|
|
|
if (name == "make_tuple") { |
|
|
|
@@ -1031,6 +1053,99 @@ void SetupDatasetIterGetNextNode(const OperatorPtr &op) { |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
void DfGraphConvertor::SetSubgraph(AnfNodePtr node) { |
|
|
|
if (!node->isa<CNode>()) { |
|
|
|
return; |
|
|
|
} |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
if (!IsCaseNode(cnode)) { |
|
|
|
return; |
|
|
|
} |
|
|
|
std::vector<AnfNodePtr> case_inputs; |
|
|
|
for (size_t i = 1; i < cnode->inputs().size(); i++) { |
|
|
|
case_inputs.emplace_back(cnode->input(i)); |
|
|
|
} |
|
|
|
std::shared_ptr<std::vector<DfGraph>> branches = std::make_shared<std::vector<DfGraph>>(); |
|
|
|
auto bnode = cnode->input(0)->cast<CNodePtr>()->input(2)->cast<CNodePtr>(); |
|
|
|
|
|
|
|
for (size_t i = 1; i < bnode->inputs().size(); i++) { |
|
|
|
auto branch_node = bnode->input(i)->cast<CNodePtr>(); |
|
|
|
for (size_t j = 2; j < branch_node->inputs().size(); j++) { |
|
|
|
if (std::find(case_inputs.begin(), case_inputs.end(), branch_node->input(j)) == case_inputs.end()) { |
|
|
|
case_inputs.emplace_back(branch_node->input(j)); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
for (size_t i = 1; i < bnode->inputs().size(); i++) { |
|
|
|
ProcessSubgraph(bnode->input(i), case_inputs); |
|
|
|
} |
|
|
|
|
|
|
|
for (size_t i = 1; i < bnode->inputs().size(); i++) { |
|
|
|
branches->emplace_back(branches_map_[bnode->input(i).get()]); |
|
|
|
} |
|
|
|
|
|
|
|
if (op_cache_.find(node.get()) == op_cache_.end()) { |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
OpAdapterPtr adpt = FindAdapter(node, training_); |
|
|
|
if (nullptr == adpt) { |
|
|
|
MS_LOG(DEBUG) << "Not found adapter"; |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
OperatorPtr op = Convert(node); |
|
|
|
adpt->setSubgraph(op, 0, branches); |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
void DfGraphConvertor::GetCaseNodeInput(const CNodePtr node, const CNodePtr input_node) { |
|
|
|
std::vector<AnfNodePtr> case_inputs; |
|
|
|
for (size_t i = 1; i < node->inputs().size(); i++) { |
|
|
|
case_inputs.emplace_back(node->input(i)); |
|
|
|
} |
|
|
|
std::shared_ptr<std::vector<DfGraph>> branches = std::make_shared<std::vector<DfGraph>>(); |
|
|
|
auto bnode = input_node->input(2)->cast<CNodePtr>(); |
|
|
|
|
|
|
|
for (size_t i = 1; i < bnode->inputs().size(); i++) { |
|
|
|
auto branch_node = bnode->input(i)->cast<CNodePtr>(); |
|
|
|
for (size_t j = 2; j < branch_node->inputs().size(); j++) { |
|
|
|
if (std::find(case_inputs.begin(), case_inputs.end(), branch_node->input(j)) == case_inputs.end()) { |
|
|
|
case_inputs.emplace_back(branch_node->input(j)); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
const size_t case_index = 1; |
|
|
|
const size_t make_tuple_index = 2; |
|
|
|
|
|
|
|
AnfNodePtr case_index_iter = input_node->input(case_index); |
|
|
|
AnfNodePtr make_tuple_iter = input_node->input(make_tuple_index); |
|
|
|
auto make_tuple_node = make_tuple_iter->cast<CNodePtr>(); |
|
|
|
std::shared_ptr<std::vector<OutHandler>> tuple_items = std::make_shared<std::vector<OutHandler>>(); |
|
|
|
|
|
|
|
for (size_t i = 0; i < case_inputs.size(); i++) { |
|
|
|
auto item = case_inputs[i]; |
|
|
|
auto op = Convert(item); |
|
|
|
if (op != nullptr) { |
|
|
|
tuple_items->emplace_back(OutHandler(op, "")); |
|
|
|
} else if (out_handle_cache_.find(item.get()) != out_handle_cache_.end()) { |
|
|
|
tuple_items->push_back(out_handle_cache_[item.get()]); |
|
|
|
} else { |
|
|
|
MS_LOG(WARNING) << "This anf node is not supported as a case input: " << item->ToString(); |
|
|
|
continue; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
tuple_out_handle_cache_[make_tuple_node.get()] = tuple_items; |
|
|
|
|
|
|
|
std::shared_ptr<std::vector<AnfNodePtr>> case_input_items = std::make_shared<std::vector<AnfNodePtr>>(); |
|
|
|
case_input_items->emplace_back(case_index_iter); |
|
|
|
case_input_items->emplace_back(make_tuple_iter); |
|
|
|
case_input_handle_cache_[node.get()] = case_input_items; |
|
|
|
} |
|
|
|
|
|
|
|
DfGraphConvertor &DfGraphConvertor::BuildGraph() { |
|
|
|
SetupDatasetIterGetNextNode(dataset_iter_getnext_); |
|
|
|
|
|
|
|
@@ -1038,6 +1153,16 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() { |
|
|
|
return *this; |
|
|
|
} |
|
|
|
|
|
|
|
// Case node set input. |
|
|
|
std::vector<AnfNodePtr> nodes = ::mindspore::TopoSort(anf_graph_->get_return()); |
|
|
|
for (auto &it : nodes) { |
|
|
|
if (it->isa<CNode>() && IsCaseNode(it->cast<CNodePtr>())) { |
|
|
|
auto node = it->cast<CNodePtr>(); |
|
|
|
auto input_node = node->input(0)->cast<CNodePtr>(); |
|
|
|
GetCaseNodeInput(node, input_node); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// update tuple_out_handle_cache_ |
|
|
|
for (auto it : tuple_out_handle_cache_) { |
|
|
|
std::size_t len = it.second->size(); |
|
|
|
@@ -1058,10 +1183,11 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() { |
|
|
|
|
|
|
|
// set up dependices |
|
|
|
MS_LOG(DEBUG) << "set up dependices"; |
|
|
|
std::vector<AnfNodePtr> nodes = ::mindspore::TopoSort(anf_graph_->get_return()); |
|
|
|
nodes = ::mindspore::TopoSort(anf_graph_->get_return()); |
|
|
|
for (auto &it : nodes) { |
|
|
|
SetNodeInput(it); |
|
|
|
SetOpControlInput(it); |
|
|
|
SetSubgraph(it); |
|
|
|
UpdateOpDesc(it); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -1077,6 +1203,18 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() { |
|
|
|
inputs.push_back(*dataset_iter_getnext_); |
|
|
|
} else { |
|
|
|
auto params = anf_graph_->parameters(); |
|
|
|
if (use_inputs_) { |
|
|
|
params = inputs_; |
|
|
|
auto anf_params = anf_graph_->parameters(); |
|
|
|
for (size_t i = 0; i < params.size(); i++) { |
|
|
|
for (size_t j = 0; j < anf_params.size(); j++) { |
|
|
|
if (params[i]->ToString() == anf_params[j]->ToString()) { |
|
|
|
params[i] = anf_params[j]; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
int index = 0; |
|
|
|
for (auto &it : params) { |
|
|
|
auto name = std::static_pointer_cast<Parameter>(it)->name(); |
|
|
|
@@ -1187,10 +1325,21 @@ const std::vector<std::string> trans_var_list = {string(kNameAssign), string(kNa |
|
|
|
|
|
|
|
void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node) { |
|
|
|
OperatorPtr src = Convert(node); |
|
|
|
int case_flag = 0; |
|
|
|
auto &inputs = node->inputs(); |
|
|
|
for (size_t i = 1; i < inputs.size(); i++) { |
|
|
|
size_t input_size = inputs.size(); |
|
|
|
if (case_input_handle_cache_.find(node.get()) != case_input_handle_cache_.end()) { |
|
|
|
case_flag = 1; |
|
|
|
input_size = case_input_handle_cache_[node.get()]->size() + 1; |
|
|
|
} |
|
|
|
|
|
|
|
for (size_t i = 1; i < input_size; i++) { |
|
|
|
auto pred = inputs[i]; |
|
|
|
while (pred->isa<CNode>() && GetCNodeFuncName(pred->cast<CNodePtr>()) == "Depend") { |
|
|
|
if (case_flag != 0) { |
|
|
|
pred = case_input_handle_cache_[node.get()]->at(i - 1); |
|
|
|
} |
|
|
|
|
|
|
|
while (pred->isa<CNode>() && GetCNodeTargetFuncName(pred->cast<CNodePtr>()) == "Depend") { |
|
|
|
pred = pred->cast<CNodePtr>()->input(1); |
|
|
|
} |
|
|
|
// skip the None input |
|
|
|
@@ -1198,7 +1347,7 @@ void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node |
|
|
|
continue; |
|
|
|
} |
|
|
|
// transform "Const" op to "Variable" op when the next node is "Assign" op. |
|
|
|
std::string c_name = GetCNodeFuncName(node); |
|
|
|
std::string c_name = GetCNodeTargetFuncName(node); |
|
|
|
auto pos = std::find(trans_var_list.begin(), trans_var_list.end(), c_name); |
|
|
|
if (!training_ && pos != trans_var_list.end() && pred->isa<Parameter>()) { |
|
|
|
std::string name = std::static_pointer_cast<Parameter>(pred)->name(); |
|
|
|
@@ -1222,7 +1371,7 @@ void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node |
|
|
|
if (it != out_handle_cache_.end()) { |
|
|
|
int ret = adpt->setInput(src, SizeToInt(i), it->second); |
|
|
|
if (ret == 0) { |
|
|
|
if (pred->isa<CNode>() && GetCNodeFuncName(pred->cast<CNodePtr>()) == "tuple_getitem") { |
|
|
|
if (pred->isa<CNode>() && GetCNodeTargetFuncName(pred->cast<CNodePtr>()) == "tuple_getitem") { |
|
|
|
compute_sout_ << op_draw_name_[pred->cast<CNodePtr>()->input(1).get()] << " -> " << op_draw_name_[node.get()] |
|
|
|
<< ":" << i << endl; |
|
|
|
} else if (pred->isa<Parameter>()) { |
|
|
|
@@ -1280,6 +1429,23 @@ void DfGraphConvertor::SetNodeInput(const AnfNodePtr node) { |
|
|
|
DfGraphConvertor::SetOpInput(adpt, cnode); |
|
|
|
} |
|
|
|
|
|
|
|
void DfGraphConvertor::ProcessSubgraph(AnfNodePtr node, const std::vector<AnfNodePtr> &inputs) { |
|
|
|
if (!node->isa<CNode>() || GetCNodeFuncName(node->cast<CNodePtr>()) != "Partial") { |
|
|
|
return; |
|
|
|
} |
|
|
|
auto graph_node = node->cast<CNodePtr>()->input(1)->cast<ValueNodePtr>(); |
|
|
|
FuncGraphPtr anf_graph = graph_node->value()->cast<FuncGraphPtr>(); |
|
|
|
DfGraphConvertor convertor(anf_graph); |
|
|
|
convertor.use_inputs_ = true; |
|
|
|
convertor.inputs_ = inputs; |
|
|
|
(void)convertor.ConvertAllNode().BuildGraph(); |
|
|
|
std::string name = graph_node->ToString() + "_ge_graph.dot"; |
|
|
|
if (MsContext::GetInstance()->save_graphs_flag()) { |
|
|
|
convertor.DrawComputeGraph(name); |
|
|
|
} |
|
|
|
branches_map_[node.get()] = *(convertor.df_graph_); |
|
|
|
} |
|
|
|
|
|
|
|
// Update GE op's shape and type info |
|
|
|
void DfGraphConvertor::UpdateOpDesc(const AnfNodePtr node) { |
|
|
|
if (nullptr == node || !node->isa<CNode>()) { |
|
|
|
@@ -1350,6 +1516,7 @@ void DfGraphConvertor::ConvertMakeTuple(const CNodePtr node) { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
MS_LOG(WARNING) << "ConvertMakeTuple: " << node.get() << " " << tuple_items->size(); |
|
|
|
tuple_out_handle_cache_[node.get()] = tuple_items; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -1713,6 +1880,14 @@ bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node) |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
if (name == "" && GetCNodeFuncName(node) == "switch_layer") { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
if (name == "Partial") { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
// make_tuple is used for a dynamic_input, convert it to a vector of OutHandlers |
|
|
|
if (name == "make_tuple") { |
|
|
|
ConvertMakeTuple(node); |
|
|
|
@@ -1734,7 +1909,7 @@ bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node) |
|
|
|
} |
|
|
|
|
|
|
|
OperatorPtr DfGraphConvertor::ConvertCNode(const CNodePtr node) { |
|
|
|
std::string name = GetCNodeFuncName(node); |
|
|
|
std::string name = GetCNodeTargetFuncName(node); |
|
|
|
if (!CheckCNode(name, node)) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
@@ -1881,7 +2056,7 @@ void DfGraphConvertor::DrawCNode(const CNodePtr node, const OpAdapterPtr adpt) { |
|
|
|
} |
|
|
|
|
|
|
|
compute_sout_ << "<tr><td colspan=\"" << (input_map.size() + dyn_input_map.size()) << "\">\"" << node->ToString() |
|
|
|
<< ":" << GetCNodeFuncName(node) << "\"</td></tr>" << endl; |
|
|
|
<< ":" << GetCNodeTargetFuncName(node) << "\"</td></tr>" << endl; |
|
|
|
|
|
|
|
// print attrs' values |
|
|
|
auto atts = adpt->GetAttrsFromDrawGraph(); |
|
|
|
|