@@ -15,6 +15,9 @@
 */
 #include "session/ascend_session.h"
 #include <algorithm>
+#include <map>
+#include <tuple>
+#include <set>
 #include "operator/ops.h"
 #include "ir/meta_tensor.h"
 #include "ir/anf.h"
@@ -75,28 +78,15 @@ void DumpGraphInputArgs(const VectorRef &args) {
 
 void SetStreamDistinctionLabel(const KernelGraphPtr &graph, uint32_t label, bool is_override) {
   MS_EXCEPTION_IF_NULL(graph);
-  for (auto &node : graph->execution_order()) {
-    if (is_override || AnfAlgo::GetStreamDistinctionLabel(node.get()) == kInvalidDistincLabel) {
-      MS_EXCEPTION_IF_NULL(node);
-      AnfAlgo::SetStreamDistinctionLabel(label, node.get());
-    }
-  }
-}
-
-GraphId GetDistinctionLabel(const KernelGraphPtr &graph) {
-  MS_EXCEPTION_IF_NULL(graph);
-  // if graph is empty,use graph id as distinction label
-  if (graph->execution_order().empty()) {
-    return graph->graph_id();
+  if (is_override || graph->stream_distinction_label() == kInvalidDistincLabel) {
+    graph->set_stream_distinction_label(label);
   }
-  // else use first node of execution order as label
-  return AnfAlgo::GetStreamDistinctionLabel(graph->execution_order()[0].get());
 }
 
 std::vector<BaseRef> GetRealArgs(const KernelGraphPtr graph, const VectorRef &args) {
   MS_EXCEPTION_IF_NULL(graph);
   std::vector<AnfNodePtr> graph_inputs = graph->inputs();
-  auto valid_inputs = graph->ValidInputs();
+  auto valid_inputs = graph->valid_inputs();
   size_t real_args_size = 0;
   std::vector<BaseRef> real_args = {};
   for (size_t i = 0; i < args.size(); i++) {
@@ -141,23 +131,9 @@ std::vector<BaseRef> GetRealArgs(const KernelGraphPtr graph, const VectorRef &ar
 
 GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
   MS_LOG(INFO) << "start";
-  auto graph_id = graph_sum_;
   // construct graph, if successfully, graph_sum_ + 1
   auto graph = ConstructKernelGraph(lst, outputs);
-  MS_EXCEPTION_IF_NULL(graph);
-  opt::AscendBackendIRFusionOptimization(graph);
-  // select kernel build info
-  SelectKernel(*graph);
-  // convert kernel Graph to model
-  predictmodel::StepConvertGraph(graph);
-  // optimize graph
-  HardwareOptimize(graph);
-  // init runtime resource
-  InitRuntimeResource();
-  // assign static memory of parameters
-  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
-  MS_EXCEPTION_IF_NULL(runtime_instance);
-  runtime_instance->AssignStaticMemoryInput(graph.get());
+  auto graph_id = graph->graph_id();
   MS_LOG(INFO) << "Compile graph " << graph_id << " success";
   return graph_id;
 }
@@ -166,16 +142,36 @@ void AscendSession::BuildGraph(GraphId graph_id) {
   MS_LOG(INFO) << "start";
   auto graph = GetGraph(graph_id);
   MS_EXCEPTION_IF_NULL(graph);
+  // resource initialize
+  InitRuntimeResource();
   // multiple graph handle
   if (graph_id == final_graph_id_) {
     if (!graph->executable()) {
       return;
     }
+    // insert assigns to child graph
+    InsertAllAssigns();
+    // insert switch and active to child graph
+    MergeSwitchCompile();
+    // OptChildGraphs
+    auto graph_order = GetGraphOrder(final_graph_id_);
+    auto &graph_type = GetGraphOrderType(final_graph_id_);
+    for (size_t i = 0; i < graph_order.size(); i++) {
+      if (graph_type[i] == BRANCH_END || graph_type[i] == BRANCH_START) {
+        continue;
+      }
+      MS_LOG(INFO) << "Start build child graph " << graph_order[i];
+      auto child_graph = GetGraph(graph_order[i]);
+      CompileChildGraph(child_graph);
+    }
     // merge child graph
     MergeGraphExecOrder();
   } else {
+    auto single_graph = GetGraph(graph_id);
+    CompileChildGraph(single_graph);
     // set the distinction label of single graph
-    SetStreamDistinctionLabel(GetGraph(graph_id), graph_id, false);
+    single_graph->set_stream_distinction_label(graph_id);
+    single_graph->UpdateExecuteKernelStreamLabel();
   }
   // adjust execution order because merge child graph and other special operations
   AdjustKernel(graph);
@@ -197,9 +193,26 @@ void AscendSession::BuildGraph(GraphId graph_id) {
     // load task info to device if it is sink mode
     LoadTask(graph);
   }
+  // sync the inital const tensor to device
+  SyncInitialTenosrToDevice();
   MS_LOG(INFO) << "end";
 }
 
+void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) {
+  MS_EXCEPTION_IF_NULL(child_graph);
+  opt::AscendBackendIRFusionOptimization(child_graph);
+  // select kernel build info
+  SelectKernel(*child_graph);
+  // convert kernel Graph to model
+  predictmodel::StepConvertGraph(child_graph);
+  // optimize graph
+  HardwareOptimize(child_graph);
+  // assign static memory of parameters
+  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
+  MS_EXCEPTION_IF_NULL(runtime_instance);
+  runtime_instance->AssignStaticMemoryInput(child_graph.get());
+}
+
 void AscendSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
                              VectorRef *const outputs) {
   MS_LOG(INFO) << "start";
@@ -458,11 +471,9 @@ void AscendSession::Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const
 
 GraphId AscendSession::SetFinalGraphInput(const std::vector<AnfNodePtr> &args) {
   MS_LOG(INFO) << "Start! Args size " << args.size();
-  auto final_graph = std::make_shared<KernelGraph>();
-  final_graph_id_ = graph_sum_++;
-  graphs_[final_graph_id_] = final_graph;
-  final_graph->set_graph_id(final_graph_id_);
-  MS_LOG(INFO) << "Create a new final graph" << final_graph_id_ << "success";
+  auto final_graph = NewKernelGraph();
+  final_graph_id_ = final_graph->graph_id();
+  MS_LOG(INFO) << "Create a new final graph" << final_graph_id_ << " success";
   // init private variables and bind them with final_graph_id
   graph_execute_orders_[final_graph_id_] = std::vector<GraphId>();
   graph_order_types_[final_graph_id_] = std::vector<GraphType>();
@@ -498,6 +509,46 @@ GraphId AscendSession::SetFinalGraphInput(const std::vector<AnfNodePtr> &args) {
   return final_graph_id_;
 }
 
+AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodePtr &true_output) {
+  auto fake_graph = GetGraph(fake_graph_id);
+  auto output_item_with_index = AnfAlgo::VisitKernelWithReturnType(true_output, 0);
+  auto create_parameter = [&](const AbstractBasePtr &abstract) -> AnfNodePtr {
+    auto parameter = fake_graph->NewParameter();
+    MS_EXCEPTION_IF_NULL(parameter);
+    parameter->set_abstract(abstract);
+    auto new_parameter = fake_graph->NewParameter(parameter);
+    // Add new parameter to the graph input of fake_graph to sure that all parameters will be allocated memory.
+    auto graph_inputs = fake_graph->MutableInputs();
+    MS_EXCEPTION_IF_NULL(graph_inputs);
+    graph_inputs->push_back(new_parameter);
+    return new_parameter;
+  };
+  auto create_parameter_from_cnode = [&](const AnfNodePtr &cnode, size_t output_idx) -> AnfNodePtr {
+    MS_EXCEPTION_IF_NULL(cnode);
+    auto abstract = cnode->abstract();
+    MS_EXCEPTION_IF_NULL(abstract);
+    // create multiple parameters if is a tuple output real kernel
+    if (abstract->isa<abstract::AbstractTuple>()) {
+      auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
+      MS_EXCEPTION_IF_NULL(tuple_abstract);
+      MS_LOG(INFO) << "tuple_size [" << tuple_abstract->size() << "]";
+      return create_parameter((*tuple_abstract)[output_idx]);
+    }
+    return create_parameter(cnode->abstract());
+  };
+  if (AnfAlgo::CheckPrimitiveType(output_item_with_index.first, prim::kPrimMakeTuple)) {
+    std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
+    auto make_tuple = output_item_with_index.first->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(make_tuple);
+    for (size_t i = 1; i < make_tuple->inputs().size(); i++) {
+      auto input = make_tuple->inputs()[i];
+      make_tuple_inputs.push_back(CreateFakeOutput(fake_graph_id, input));
+    }
+    return fake_graph->NewCNode(make_tuple_inputs);
+  }
+  return create_parameter_from_cnode(output_item_with_index.first, output_item_with_index.second);
+}
+
 void AscendSession::SetFinalGraphOutput(const BaseRef &output) {
   auto final_graph = GetGraph(final_graph_id_);
   MS_EXCEPTION_IF_NULL(final_graph);
@@ -559,12 +610,6 @@ void AscendSession::InsertSwitchToGraph(GraphId condition_graph_id, GraphId true
   condition_graph->AddValueNodeToGraph(counter_const);
   // create a new switch op
   auto switch_primitive = std::make_shared<Primitive>("StreamSwitch");
-  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
-  kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
-  kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{kNumberTypeInt32});
-  kernel_build_info_builder->SetFusionType(kernel::FusionType::OPAQUE);
-  kernel_build_info_builder->SetProcessor(kernel::Processor::AICORE);
-  kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL);
   auto cond_output_it = condition_output_.find(condition_graph_id);
   if (cond_output_it == condition_output_.end()) {
     MS_LOG(EXCEPTION) << "Can't find condition graph" << condition_graph_id;
@@ -574,11 +619,9 @@ void AscendSession::InsertSwitchToGraph(GraphId condition_graph_id, GraphId true
   MS_EXCEPTION_IF_NULL(cond_output_kernel);
   std::vector<AnfNodePtr> inputs = {NewValueNode(switch_primitive), cond_output_kernel, counter_const};
   CNodePtr switch_node = condition_graph->NewCNode(inputs);
-  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), switch_node.get());
   MS_EXCEPTION_IF_NULL(switch_node);
   switch_node->set_abstract(std::make_shared<abstract::AbstractNone>());
   AnfAlgo::SetGraphId(condition_graph_id, switch_node.get());
-  AnfAlgo::SetStreamDistinctionLabel(GetDistinctionLabel(GetGraph(condition_graph_id)), switch_node.get());
   // set attr: cond_ RT_GREATER
   AnfAlgo::SetNodeAttr(kAttrSwitchCondition, MakeValue<int>(static_cast<int>(RT_GREATER)), switch_node);
   // set attr:data_type
@@ -586,9 +629,9 @@ void AscendSession::InsertSwitchToGraph(GraphId condition_graph_id, GraphId true
   // set attr:true branch graph id ,which is same to stream distinction label
   AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(true_graph_id), switch_node);
   // append switch at the end of condition graph
-  std::vector<CNodePtr> exec_order = condition_graph->execution_order();
-  exec_order.push_back(switch_node);
-  condition_graph->set_execution_order(exec_order);
+  auto return_node = condition_graph->get_return();
+  MS_EXCEPTION_IF_NULL(return_node);
+  InsertControlDependToGraph(condition_graph_id, return_node->input(1), switch_node);
   MS_LOG(INFO) << "Finish!";
 }
 
@@ -615,8 +658,14 @@ void AscendSession::CopyOutputOfIf(GraphId false_graph_id) {
       MS_EXCEPTION_IF_NULL(true_last);
       MS_EXCEPTION_IF_NULL(false_last);
      MS_LOG(INFO) << "The last graph of false branch is " << false_last_id;
-      // now only consider the single output
-      InsertMultipleAssignToGraph(true_last_id, true_last->output(), false_last->output());
+      // create fake output
+      auto fake_output_graph = NewKernelGraph();
+      graph_execute_order.push_back(fake_output_graph->graph_id());
+      graph_order_type.push_back(COMMON_GRAPH);
+      fake_output_graph->set_output(CreateFakeOutput(fake_output_graph->graph_id(), final_graph->output()));
+      final_graph->set_output(fake_output_graph->output());
+      InsertMultipleAssignToGraph(true_last_id, true_last->output(), final_graph->output());
+      InsertMultipleAssignToGraph(false_last_id, false_last->output(), final_graph->output());
       // insert stream active for loop sink
       auto context_ptr = MsContext::GetInstance();
       MS_EXCEPTION_IF_NULL(context_ptr);
@@ -650,14 +699,14 @@ void AscendSession::SwitchCompile(GraphId cond_graph_id, GraphId true_graph_id,
   if (false_graph_id != kInvalidGraphId) {
     // false graph and condition in graph same stream
     auto condition_graph = GetGraph(cond_graph_id);
-    SetStreamDistinctionLabel(GetGraph(false_graph_id), GetDistinctionLabel(condition_graph), true);
+    SetStreamDistinctionLabel(GetGraph(false_graph_id), condition_graph->stream_distinction_label(), true);
     // if false graph is a condition graph and has been switch compiled before,it's false should be updated again
     auto cond_it = switches_.find(false_graph_id);
     while (cond_it != switches_.end() && cond_it->second.second != kInvalidGraphId) {
       cond_graph_id = cond_it->first;
       false_graph_id = cond_it->second.second;
       condition_graph = GetGraph(cond_graph_id);
-      SetStreamDistinctionLabel(GetGraph(false_graph_id), GetDistinctionLabel(condition_graph), true);
+      SetStreamDistinctionLabel(GetGraph(false_graph_id), condition_graph->stream_distinction_label(), true);
       cond_it = switches_.find(false_graph_id);
     }
   }
@@ -691,7 +740,7 @@ void AscendSession::MergeSwitchCompile() {
     }
     // insert stream active to common graph
     if (prev_graph_id != kInvalidGraphId) {
-      InsertStreamActiveToGraph(prev_graph_id, GetDistinctionLabel(condition_graph));
+      InsertStreamActiveToGraph(prev_graph_id, condition_graph->stream_distinction_label());
     }
     // if this is a 'if' condition
     auto it = while_condition_graphs_.find(cond_graph_id);
@@ -700,12 +749,39 @@ void AscendSession::MergeSwitchCompile() {
     } else {
       // if it is a while,insert a stream active to true graph
       GraphId from_graph = it->second;
-      InsertStreamActiveToGraph(from_graph, GetDistinctionLabel(condition_graph));
+      InsertStreamActiveToGraph(from_graph, condition_graph->stream_distinction_label());
     }
   }
   MS_LOG(INFO) << "Finish!";
 }
 
+void AscendSession::InsertAllAssigns() {
+  std::set<std::pair<AnfNodePtr, AnfNodePtr>> assigns;
+  for (auto assign : assigns_) {
+    auto front_anf = std::get<0>(assign);
+    auto to_graph_id = std::get<1>(assign);
+    auto input_idx = std::get<2>(assign);
+    auto to_graph = GetGraph(to_graph_id);
+    MS_EXCEPTION_IF_NULL(to_graph);
+    std::vector<AnfNodePtr> graph_inputs = to_graph->inputs();
+    if (input_idx >= graph_inputs.size()) {
+      MS_LOG(EXCEPTION) << "input_index " << input_idx << " out of range size " << graph_inputs.size();
+    }
+    auto backend_parameter = graph_inputs[input_idx];
+    (void)assigns.insert(std::pair<AnfNodePtr, AnfNodePtr>(front_anf, backend_parameter));
+  }
+  // erase the repeat assign
+  for (auto &assign : assigns) {
+    auto front_anf = assign.first;
+    auto backend_parameter = assign.second;
+    auto from_graph_id = GetGraphIdByNode(front_anf);
+    auto from_graph = GetGraph(from_graph_id);
+    MS_EXCEPTION_IF_NULL(from_graph);
+    auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf);
+    InsertAssignToGraph(from_graph_id, backend_arg, backend_parameter);
+  }
+}
+
 // insert active to graph
 void AscendSession::SetActive(GraphId from, GraphId to) {
   if (while_condition_graphs_.find(to) != while_condition_graphs_.end()) {
@@ -735,20 +811,21 @@ void AscendSession::SetActive(GraphId from, GraphId to) {
   while_condition_graphs_[to] = from;
 }
 
-void AscendSession::SetChildGraphParameter(const AnfNodePtr &front_anf, const AnfNodePtr &backend_parameter) {
+void AscendSession::SetChildGraphParameter(const AnfNodePtr &front_anf, GraphId to_graph_id, size_t input_idx) {
   MS_LOG(INFO) << "Start!";
-  MS_EXCEPTION_IF_NULL(backend_parameter);
   MS_EXCEPTION_IF_NULL(front_anf);
-  if (!backend_parameter->isa<Parameter>()) {
-    MS_LOG(EXCEPTION) << "Backend parameter's type is not a parameter,but is " << backend_parameter->ToString();
-  }
   auto from_graph_id = GetGraphIdByNode(front_anf);
   auto from_graph = GetGraph(from_graph_id);
   MS_EXCEPTION_IF_NULL(from_graph);
-  auto to_graph_id = AnfAlgo::GetGraphId(backend_parameter.get());
   auto to_graph = GetGraph(to_graph_id);
-  auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf);
   MS_EXCEPTION_IF_NULL(to_graph);
+  std::vector<AnfNodePtr> graph_inputs = to_graph->inputs();
+  if (input_idx >= graph_inputs.size()) {
+    MS_LOG(EXCEPTION) << "input_index " << input_idx << " out of range size " << graph_inputs.size();
+  }
+  auto backend_parameter = graph_inputs[input_idx];
+  MS_EXCEPTION_IF_NULL(backend_parameter);
+  auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf);
   MS_LOG(INFO) << "Set node[" << front_anf->DebugString() << "] of graph[" << from_graph_id << "]to node["
                << backend_parameter->DebugString() << "] of graph[" << AnfAlgo::GetGraphId(backend_parameter.get())
                << "]";
@@ -759,39 +836,21 @@ void AscendSession::SetChildGraphParameter(const AnfNodePtr &front_anf, const An
   // if arg is the the parameter of child graph,it is parameter of final graph too
   if (front_anf->isa<Parameter>()) {
     MS_EXCEPTION_IF_NULL(backend_arg);
-    if (!AnfAlgo::OutputAddrExist(backend_arg, 0)) {
-      // set parameter's addr in child graph to parameter in final graph
-      AnfAlgo::SetOutputAddr(AnfAlgo::GetMutableOutputAddr(backend_parameter, 0), 0, backend_arg.get());
-      MS_LOG(INFO) << "Assign mem of node" << backend_parameter->DebugString() << " of graph "
-                   << AnfAlgo::GetGraphId(backend_parameter.get()) << " to node" << backend_arg->DebugString()
-                   << "of graph " << AnfAlgo::GetGraphId(backend_arg.get());
-      return;
-    }
-    // if a parameter is a weight and not linked to any executable node,device type will be kTypeUnknown,set it's device
-    // type same to arg
-    if (AnfAlgo::GetOutputDeviceDataType(backend_parameter, 0) == kTypeUnknown) {
-      AnfAlgo::SetSelectKernelBuildInfo(AnfAlgo::GetSelectKernelBuildInfo(backend_arg), backend_parameter.get());
-    }
-    // if front anf is a parameter,we can assign the value back,because backend_parameter won't be change in it's graph
-    // unless it's a weight.If backend_parameter is a weight,we should assign the value back.
-    AnfAlgo::SetOutputAddr(AnfAlgo::GetMutableOutputAddr(backend_arg, 0), 0, backend_parameter.get());
+    MS_LOG(INFO) << "Reuse node [" << backend_arg->DebugString() << "], old node[" << backend_parameter->DebugString()
+                 << "] will be replaced.";
+    to_graph->ReplaceNode(backend_parameter, backend_arg);
     return;
   }
-  InsertAssignToGraph(from_graph_id, backend_arg, backend_parameter);
-  MS_LOG(INFO) << "Finish!";
+  MS_LOG(INFO) << "Assign of node" << backend_arg->DebugString() << " of graph " << from_graph_id << " to node"
+               << backend_parameter->DebugString() << "of graph " << to_graph_id;
+  (void)assigns_.insert(std::tuple<AnfNodePtr, GraphId, size_t>(front_anf, to_graph_id, input_idx));
 }
 
-void AscendSession::SetChildGraphParameter(const tensor::TensorPtr &front_tensor, const AnfNodePtr &backend_parameter) {
+void AscendSession::SetChildGraphParameter(const tensor::TensorPtr &front_tensor, GraphId to_graph_id,
+                                           size_t input_idx) {
   MS_LOG(INFO) << "Start!";
-  // sync data from host to device
-  MS_EXCEPTION_IF_NULL(front_tensor);
-  size_t tensor_size = front_tensor->data().nbytes();
-  auto addr = AnfAlgo::GetOutputAddr(backend_parameter, 0);
-  MS_EXCEPTION_IF_NULL(addr);
-  if (!addr->SyncHostToDevice(trans::GetRuntimePaddingShape(backend_parameter, 0), tensor_size,
-                              front_tensor->data_type(), front_tensor->data_c(false))) {
-    MS_LOG(EXCEPTION) << "Tensor SyncHostToDevice fail!";
-  }
+  std::pair<GraphId, size_t> graph_input_pair(to_graph_id, input_idx);
+  initial_tenosrs_[graph_input_pair] = front_tensor;
   MS_LOG(INFO) << "Finish!";
 }
 
@@ -818,10 +877,9 @@ size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const AnfN
   if (output_num > 1 && !AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
     return input_index + output_num;
   }
-  auto &graph_inputs = graph->inputs();
-  auto &valid_inputs = graph->ValidInputs();
+  auto valid_inputs = graph->valid_inputs();
   if (valid_inputs[input_index]) {
-    SetChildGraphParameter(node, graph_inputs[input_index]);
+    SetChildGraphParameter(node, graph->graph_id(), input_index);
   } else {
     MS_LOG(DEBUG) << "Invalid input arg: " << node->DebugString();
   }
@@ -833,8 +891,7 @@ size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const Valu
   if (!value->isa<Tensor>()) {
     MS_LOG(EXCEPTION) << "Value Node should be a tensor, unexpected value: " << value->ToString();
   }
-  auto &graph_inputs = graph->inputs();
-  SetChildGraphParameter(value->cast<TensorPtr>(), graph_inputs[input_index]);
+  SetChildGraphParameter(value->cast<TensorPtr>(), graph->graph_id(), input_index);
   return ++input_index;
 }
 
@@ -905,8 +962,6 @@ GraphId AscendSession::GetGraphIdByNode(const AnfNodePtr &front_anf) const {
 
 void AscendSession::MergeGraphExecOrder() {
   MS_LOG(INFO) << "Start!";
-  // insert switch to graph
-  MergeSwitchCompile();
   // merge graph order
   auto &graph_order = GetGraphOrder(final_graph_id_);
   auto &graph_type = GetGraphOrderType(final_graph_id_);
@@ -916,6 +971,13 @@ void AscendSession::MergeGraphExecOrder() {
     MS_LOG(WARNING) << "Graph output is a lonely variable not linked to any op!";
     return;
   }
+  if (graph_order.size() > 1) {
+    auto context_ptr = MsContext::GetInstance();
+    MS_EXCEPTION_IF_NULL(context_ptr);
+    if (!context_ptr->enable_task_sink()) {
+      MS_LOG(INFO) << "Control sink network should run with task-sink mode!";
+    }
+  }
   // if first graph is common,the final graph has no label,then set the stream of final graph same with the first graph
   SetStreamDistinctionLabel(final_graph, graph_order[0], false);
   std::vector<CNodePtr> final_exec_order = final_graph->execution_order();
@@ -930,7 +992,11 @@ void AscendSession::MergeGraphExecOrder() {
     MS_EXCEPTION_IF_NULL(child_graph);
     auto exec_order = child_graph->execution_order();
     MS_LOG(INFO) << "Merge graph,graph_id " << graph_id;
-    (void)std::copy(exec_order.begin(), exec_order.end(), std::back_inserter(final_exec_order));
+    (void)std::transform(exec_order.begin(), exec_order.end(), std::back_inserter(final_exec_order),
+                         [&](CNodePtr node) -> CNodePtr {
+                           AnfAlgo::SetStreamDistinctionLabel(child_graph->stream_distinction_label(), node.get());
+                           return node;
+                         });
     // add all value nodes of child graphs to final graph
     for (auto &value_node : child_graph->graph_value_nodes()) {
       final_graph->AddValueNodeToGraph(value_node);
@@ -969,15 +1035,9 @@ void AscendSession::InsertAssignToGraph(GraphId graph_id, const AnfNodePtr &from
   // generate a new cnode
   auto assign_node = graph->NewCNode(inputs);
   MS_EXCEPTION_IF_NULL(assign_node);
-  assign_node->set_abstract(std::make_shared<abstract::AbstractNone>());
-  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
-  kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL);
-  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), assign_node.get());
-  AnfAlgo::SetStreamDistinctionLabel(GetDistinctionLabel(graph), assign_node.get());
+  assign_node->set_abstract(to->abstract());
   // append the assign at the end of from graph
-  auto exec_order = graph->execution_order();
-  exec_order.push_back(assign_node);
-  graph->set_execution_order(exec_order);
+  InsertDependToGraph(graph_id, assign_node);
 }
 
 void AscendSession::InsertMultipleAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to) {
@@ -997,24 +1057,46 @@ void AscendSession::InsertMultipleAssignToGraph(GraphId graph_id, const AnfNodeP
 
 void AscendSession::InsertStreamActiveToGraph(GraphId graph_id, uint32_t actived_stream) {
   MS_LOG(INFO) << "Insert stream_active from " << graph_id << " to " << actived_stream;
-  auto from_graph = graphs_[graph_id];
+  auto from_graph = GetGraph(graph_id);
   MS_EXCEPTION_IF_NULL(from_graph);
   std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("StreamActive"))};
   auto active_node = from_graph->NewCNode(inputs);
   MS_EXCEPTION_IF_NULL(active_node);
   active_node->set_abstract(std::make_shared<abstract::AbstractNone>());
-  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
-  kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL);
-  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), active_node.get());
   // set the active stream id into the attr of active node
   std::vector<uint32_t> active_index_value = {};
   active_index_value.push_back(actived_stream);
   AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(active_index_value), active_node);
-  AnfAlgo::SetStreamDistinctionLabel(GetDistinctionLabel(from_graph), active_node.get());
   // append the active node at the end of from graph
-  auto exec_order = from_graph->execution_order();
-  exec_order.push_back(active_node);
-  from_graph->set_execution_order(exec_order);
+  auto return_node = from_graph->get_return();
+  MS_EXCEPTION_IF_NULL(return_node);
+  InsertControlDependToGraph(graph_id, return_node->input(1), active_node);
+}
+
+void AscendSession::InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attch_node) {
+  MS_LOG(INFO) << "Insert depend at the end of graph, the attach node is " << attch_node->DebugString();
+  auto graph = GetGraph(graph_id);
+  MS_EXCEPTION_IF_NULL(graph);
+  std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("depend"))};
+  auto return_node = graph->get_return();
+  MS_EXCEPTION_IF_NULL(return_node);
+  inputs.push_back(return_node->input(1));
+  inputs.push_back(attch_node);
+  auto depend_node = graph->NewCNode(inputs);
+  return_node->set_input(1, depend_node);
+}
+
+void AscendSession::InsertControlDependToGraph(GraphId graph_id, const AnfNodePtr &first_node,
+                                               const AnfNodePtr &second_node) {
+  MS_LOG(INFO) << "Insert control depend at the end of graph, the first node is " << first_node->DebugString()
+               << ", the second node is " << second_node->DebugString();
+  auto graph = GetGraph(graph_id);
+  MS_EXCEPTION_IF_NULL(graph);
+  std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("ControlDepend"))};
+  inputs.push_back(first_node);
+  inputs.push_back(second_node);
+  auto control_depend = graph->NewCNode(inputs);
+  InsertDependToGraph(graph_id, control_depend);
 }
 
 size_t AscendSession::ExecOrderOfChildGraph(GraphId final_graph, GraphId child_graph) {
@@ -1043,5 +1125,29 @@ std::vector<GraphType> &AscendSession::GetGraphOrderType(GraphId final_graph_id)
   }
   return graph_type_iter->second;
 }
+
+void AscendSession::SyncInitialTenosrToDevice() {
+  for (auto &item : initial_tenosrs_) {
+    auto to_graph_id = item.first.first;
+    auto input_idx = item.first.second;
+    auto front_tensor = item.second;
+    auto to_graph = GetGraph(to_graph_id);
+    MS_EXCEPTION_IF_NULL(to_graph);
+    std::vector<AnfNodePtr> graph_inputs = to_graph->inputs();
+    if (input_idx >= graph_inputs.size()) {
+      MS_LOG(EXCEPTION) << "input_index " << input_idx << " out of range size " << graph_inputs.size();
+    }
+    auto backend_parameter = graph_inputs[input_idx];
+    // sync data from host to device
+    MS_EXCEPTION_IF_NULL(front_tensor);
+    size_t tensor_size = front_tensor->data().nbytes();
+    auto addr = AnfAlgo::GetOutputAddr(backend_parameter, 0);
+    MS_EXCEPTION_IF_NULL(addr);
+    if (!addr->SyncHostToDevice(trans::GetRuntimePaddingShape(backend_parameter, 0), tensor_size,
+                                front_tensor->data_type(), front_tensor->data_c(false))) {
+      MS_LOG(EXCEPTION) << "Tensor SyncHostToDevice fail!";
+    }
+  }
+}
 }  // namespace session
 }  // namespace mindspore