|
|
|
@@ -417,6 +417,7 @@ bool MindRTBackend::CompileGraph(const FuncGraphPtr &func_graph) { |
|
|
|
MS_LOG(INFO) << "Compile graph: " << func_graph->ToString() << ", Split segments size:" << segments.size(); |
|
|
|
const auto &device_context = |
|
|
|
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name_, device_id_}); |
|
|
|
MS_EXCEPTION_IF_NULL(device_context); |
|
|
|
const auto &new_segments = device_context->PartitionGraph(func_graph, segments); |
|
|
|
|
|
|
|
// Compile the whole function graph if not split graph. |
|
|
|
@@ -447,6 +448,7 @@ void MindRTBackend::CompileGraph(const GraphSegmentPtr &segment, bool contain_mu |
|
|
|
const auto &cur_device_name = GetCNodeTarget(segment->nodes_[0]); |
|
|
|
const auto &device_context = |
|
|
|
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({cur_device_name, device_id_}); |
|
|
|
MS_EXCEPTION_IF_NULL(device_context); |
|
|
|
device_context->Initialize(); |
|
|
|
|
|
|
|
// Transform nodes to inputs and outputs. |
|
|
|
@@ -487,6 +489,7 @@ const ActorInfo &MindRTBackend::CompileGraph(const OpRunInfo &op_run_info, const |
|
|
|
// Get the device context. |
|
|
|
const auto &device_context = |
|
|
|
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name_, device_id_}); |
|
|
|
MS_EXCEPTION_IF_NULL(device_context); |
|
|
|
device_context->Initialize(); |
|
|
|
|
|
|
|
bool single_op_cache_hit = true; |
|
|
|
@@ -527,6 +530,7 @@ void GetControlOpInput(const std::shared_ptr<GraphCompiler> &graph_compiler, con |
|
|
|
MS_EXCEPTION_IF_NULL(front_cnode); |
|
|
|
MS_EXCEPTION_IF_NULL(backend_cnode); |
|
|
|
MS_EXCEPTION_IF_NULL(graph_compiler); |
|
|
|
MS_EXCEPTION_IF_NULL(args); |
|
|
|
size_t input_index = 0; |
|
|
|
auto inputs = front_cnode->inputs(); |
|
|
|
for (size_t i = 1; i < inputs.size(); i++) { |
|
|
|
@@ -596,12 +600,11 @@ void ConvertMultiPyObjectToTensor(const py::object &input_object, std::vector<te |
|
|
|
MS_LOG(EXCEPTION) << "The input should be a tuple!"; |
|
|
|
} |
|
|
|
|
|
|
|
auto tuple_inputs = py::cast<py::tuple>(input_object); |
|
|
|
if (tuple_inputs.empty()) { |
|
|
|
auto inputs = py::cast<py::tuple>(input_object); |
|
|
|
if (inputs.empty()) { |
|
|
|
MS_LOG(EXCEPTION) << "The size of input list or tuple is 0!"; |
|
|
|
} |
|
|
|
|
|
|
|
auto inputs = py::cast<py::tuple>(input_object); |
|
|
|
if (py::isinstance<tensor::Tensor>(inputs[0])) { |
|
|
|
PlantTensorTupleToVector(inputs, tensors); |
|
|
|
} else { |
|
|
|
@@ -615,12 +618,15 @@ void RunControlOperator(const std::shared_ptr<GraphCompiler> &graph_compiler, co |
|
|
|
const std::vector<tensor::TensorPtr> &graph_inputs, InputTensorInfo *input_tensor_info, |
|
|
|
VectorRef *op_outputs) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel); |
|
|
|
MS_EXCEPTION_IF_NULL(op_outputs); |
|
|
|
AnfNodePtr front_node = graph->GetFrontAnfByBackendAnf(kernel); |
|
|
|
MS_EXCEPTION_IF_NULL(front_node); |
|
|
|
if (!front_node->isa<CNode>()) { |
|
|
|
MS_LOG(EXCEPTION) << "The front node of bprop_cut is not CNode"; |
|
|
|
} |
|
|
|
CNodePtr cnode = front_node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
const std::vector<AnfNodePtr> &node_inputs = cnode->inputs(); |
|
|
|
if (node_inputs.empty()) { |
|
|
|
MS_LOG(EXCEPTION) << "The inputs of node[" << cnode->fullname_with_scope() << "] is empty"; |
|
|
|
@@ -633,6 +639,7 @@ void RunControlOperator(const std::shared_ptr<GraphCompiler> &graph_compiler, co |
|
|
|
} |
|
|
|
|
|
|
|
PrimitivePtr prim = GetValueNode<PrimitivePtr>(fn); |
|
|
|
MS_EXCEPTION_IF_NULL(prim); |
|
|
|
if (prim->name() == kBpropCutOpName) { |
|
|
|
VectorRef args; |
|
|
|
GetControlOpInput(graph_compiler, cnode, kernel, op_output_map, parameter_index, graph_inputs, input_tensor_info, |
|
|
|
@@ -798,7 +805,8 @@ void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args, |
|
|
|
if (graph_iter == actor_to_graph_compiler_info_.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Can't find the graph compiler info."; |
|
|
|
} |
|
|
|
const auto &graph_compiler_info = *(graph_iter->second.get()); |
|
|
|
MS_EXCEPTION_IF_NULL(graph_iter->second); |
|
|
|
const auto &graph_compiler_info = *(graph_iter->second); |
|
|
|
const auto &origin_parameters = graph_compiler_info.origin_parameters_order_; |
|
|
|
|
|
|
|
// Transform args to input tensors. |
|
|
|
@@ -842,6 +850,9 @@ void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args, |
|
|
|
MS_LOG(EXCEPTION) << "The actor runs failed, actor name: " << actor_set->name_; |
|
|
|
} |
|
|
|
|
|
|
|
if (graph_compiler_info.device_contexts_.empty()) { |
|
|
|
MS_LOG(EXCEPTION) << "The device contexts is empty."; |
|
|
|
} |
|
|
|
// Sync device stream. |
|
|
|
const auto &first_device_context = graph_compiler_info.device_contexts_[0]; |
|
|
|
MS_EXCEPTION_IF_NULL(first_device_context); |
|
|
|
@@ -877,6 +888,7 @@ void MindRTBackend::ConstructOutputs(const AnfNodePtr &output_node, |
|
|
|
VectorRef *outputs) { |
|
|
|
MS_EXCEPTION_IF_NULL(output_node); |
|
|
|
MS_EXCEPTION_IF_NULL(outputs); |
|
|
|
MS_EXCEPTION_IF_NULL(output_position); |
|
|
|
// The makeTuple node need expand and recurse. |
|
|
|
if (AnfAlgo::CheckPrimitiveType(output_node, prim::kPrimMakeTuple)) { |
|
|
|
auto make_tuple = output_node->cast<CNodePtr>(); |
|
|
|
@@ -994,7 +1006,6 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(con |
|
|
|
std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo( |
|
|
|
const ActorInfo &actor_info, const std::vector<int64_t> *tensors_mask, |
|
|
|
const std::vector<tensor::TensorPtr> *input_tensors, bool need_erase) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph_compiler_); |
|
|
|
std::vector<KernelGraphPtr> graphs; |
|
|
|
std::vector<DeviceContext *> device_contexts; |
|
|
|
runtime::KernelMapPosition outputs_order; |
|
|
|
@@ -1027,10 +1038,12 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo( |
|
|
|
} |
|
|
|
|
|
|
|
void MindRTBackend::EraseSingleOpCache(const ActorInfo &actor_info, const KernelGraphPtr &graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
if (graph_info_to_device_context_.empty()) { |
|
|
|
MS_LOG(EXCEPTION) << "The map graph_info_to_device_context_ is empty."; |
|
|
|
} |
|
|
|
const auto &graph_info = graph_info_to_device_context_.begin()->first; |
|
|
|
MS_EXCEPTION_IF_NULL(graph_compiler_); |
|
|
|
graph_compiler_->EraseSingleOpCache(graph_info, graph->graph_id()); |
|
|
|
actor_to_graph_compiler_info_.erase(actor_info); |
|
|
|
} |
|
|
|
@@ -1045,6 +1058,7 @@ void MindRTBackend::RunGraph(const ActorInfo &actor_info, OpRunInfo *op_run_info |
|
|
|
if (graph_iter == actor_to_graph_compiler_info_.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Can't find the graph compiler info."; |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(graph_iter->second); |
|
|
|
const auto &graph_compiler_info = *(graph_iter->second); |
|
|
|
|
|
|
|
const auto &actor_set = runtime::GraphScheduler::GetInstance().Fetch(actor_info); |
|
|
|
|