From: @zyli2020 Reviewed-by: Signed-off-by:pull/14610/MERGE
| @@ -19,6 +19,8 @@ | |||||
| #include <map> | #include <map> | ||||
| #include <set> | #include <set> | ||||
| #include <unordered_set> | #include <unordered_set> | ||||
| #include <functional> | |||||
| #include <numeric> | |||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "base/core_ops.h" | #include "base/core_ops.h" | ||||
| @@ -480,6 +482,28 @@ size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) { | |||||
| return 1; | return 1; | ||||
| } | } | ||||
| size_t AnfRuntimeAlgorithm::GetOutputTensorMemSize(const AnfNodePtr &node, size_t output_index) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (output_index >= AnfAlgo::GetOutputTensorNum(node)) { | |||||
| MS_EXCEPTION(ArgumentError) << "output index [" << output_index << "] large than the output size [" | |||||
| << AnfAlgo::GetOutputTensorNum(node) << "] of node!"; | |||||
| } | |||||
| TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index); | |||||
| if (output_type_id == kTypeUnknown) { | |||||
| output_type_id = AnfAlgo::GetOutputInferDataType(node, output_index); | |||||
| } | |||||
| size_t type_size = GetTypeByte(TypeIdToType(output_type_id)); | |||||
| std::vector<size_t> shape = AnfAlgo::GetOutputDeviceShape(node, output_index); | |||||
| auto format = AnfAlgo::GetOutputFormat(node, output_index); | |||||
| if (shape.empty() && format != kOpFormat_DEFAULT) { | |||||
| shape = trans::PaddingShape(shape, format, AnfAlgo::GetOutputReshapeType(node, output_index)); | |||||
| shape = trans::TransShapeToDevice(shape, format); | |||||
| } | |||||
| // scalar's output shape is a empty vector | |||||
| size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>()); | |||||
| return tensor_size; | |||||
| } | |||||
| std::vector<std::string> AnfRuntimeAlgorithm::GetAllOutputFormats(const AnfNodePtr &node) { | std::vector<std::string> AnfRuntimeAlgorithm::GetAllOutputFormats(const AnfNodePtr &node) { | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| if (!AnfAlgo::IsRealKernel(node)) { | if (!AnfAlgo::IsRealKernel(node)) { | ||||
| @@ -105,6 +105,8 @@ class AnfRuntimeAlgorithm { | |||||
| static size_t GetInputTensorNum(const AnfNodePtr &node); | static size_t GetInputTensorNum(const AnfNodePtr &node); | ||||
| // get the num of output real_kernel(which can be build and run in device) | // get the num of output real_kernel(which can be build and run in device) | ||||
| static size_t GetOutputTensorNum(const AnfNodePtr &node); | static size_t GetOutputTensorNum(const AnfNodePtr &node); | ||||
| // Get the memory size of output tensor of node. | |||||
| static size_t GetOutputTensorMemSize(const AnfNodePtr &node, size_t output_index); | |||||
| // get all outputs format select of anf node | // get all outputs format select of anf node | ||||
| static std::vector<std::string> GetAllOutputFormats(const AnfNodePtr &node); | static std::vector<std::string> GetAllOutputFormats(const AnfNodePtr &node); | ||||
| // get all inputs format select of anf node | // get all inputs format select of anf node | ||||
| @@ -16,7 +16,6 @@ | |||||
| #include "runtime/device/kernel_runtime.h" | #include "runtime/device/kernel_runtime.h" | ||||
| #include <functional> | #include <functional> | ||||
| #include <numeric> | |||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include "backend/optimizer/common/helper.h" | #include "backend/optimizer/common/helper.h" | ||||
| @@ -57,28 +56,6 @@ bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_ | |||||
| return false; | return false; | ||||
| } | } | ||||
| size_t KernelRuntime::CountNodeDeviceMemorySize(const mindspore::AnfNodePtr &node, size_t output_index) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (output_index >= AnfAlgo::GetOutputTensorNum(node)) { | |||||
| MS_EXCEPTION(ArgumentError) << "output index [" << output_index << "] large than the output size [" | |||||
| << AnfAlgo::GetOutputTensorNum(node) << "] of node!"; | |||||
| } | |||||
| TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index); | |||||
| if (output_type_id == kTypeUnknown) { | |||||
| output_type_id = AnfAlgo::GetOutputInferDataType(node, output_index); | |||||
| } | |||||
| size_t type_size = GetTypeByte(TypeIdToType(output_type_id)); | |||||
| std::vector<size_t> shape = AnfAlgo::GetOutputDeviceShape(node, output_index); | |||||
| auto format = AnfAlgo::GetOutputFormat(node, output_index); | |||||
| if (shape.empty() && format != kOpFormat_DEFAULT) { | |||||
| shape = trans::PaddingShape(shape, format, AnfAlgo::GetOutputReshapeType(node, output_index)); | |||||
| shape = trans::TransShapeToDevice(shape, format); | |||||
| } | |||||
| // scalar's output shape is a empty vector | |||||
| size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>()); | |||||
| return tensor_size; | |||||
| } | |||||
| void KernelRuntime::AssignMemory(session::KernelGraph *graph) { | void KernelRuntime::AssignMemory(session::KernelGraph *graph) { | ||||
| auto context_ptr = MsContext::GetInstance(); | auto context_ptr = MsContext::GetInstance(); | ||||
| MS_EXCEPTION_IF_NULL(context_ptr); | MS_EXCEPTION_IF_NULL(context_ptr); | ||||
| @@ -184,7 +161,7 @@ void KernelRuntime::RunOpAssignInputMemory(const std::vector<tensor::TensorPtr> | |||||
| if (output_type_id == kTypeUnknown) { | if (output_type_id == kTypeUnknown) { | ||||
| output_type_id = AnfAlgo::GetOutputInferDataType(item, index); | output_type_id = AnfAlgo::GetOutputInferDataType(item, index); | ||||
| } | } | ||||
| auto tensor_size = CountNodeDeviceMemorySize(item, index); | |||||
| auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index); | |||||
| auto device_address = | auto device_address = | ||||
| CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id); | CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id); | ||||
| MS_EXCEPTION_IF_NULL(device_address); | MS_EXCEPTION_IF_NULL(device_address); | ||||
| @@ -361,7 +338,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| #endif | #endif | ||||
| auto tensor_size = CountNodeDeviceMemorySize(item, index); | |||||
| auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index); | |||||
| device_address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id); | device_address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id); | ||||
| MS_LOG(DEBUG) << "Malloc static memory for " << item->fullname_with_scope(); | MS_LOG(DEBUG) << "Malloc static memory for " << item->fullname_with_scope(); | ||||
| if (mem_manager_->MallocMem(kStaticMem, tensor_size, device_address, graph->graph_id()) == nullptr) { | if (mem_manager_->MallocMem(kStaticMem, tensor_size, device_address, graph->graph_id()) == nullptr) { | ||||
| @@ -656,7 +633,7 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const | |||||
| continue; | continue; | ||||
| } | } | ||||
| size_t tensor_size = tensor->data().nbytes(); | size_t tensor_size = tensor->data().nbytes(); | ||||
| auto node_size = CountNodeDeviceMemorySize(value_node, output_idx); | |||||
| auto node_size = AnfAlgo::GetOutputTensorMemSize(value_node, output_idx); | |||||
| TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx); | TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx); | ||||
| if (output_type_id == kTypeUnknown) { | if (output_type_id == kTypeUnknown) { | ||||
| output_type_id = AnfAlgo::GetOutputInferDataType(value_node, output_idx); | output_type_id = AnfAlgo::GetOutputInferDataType(value_node, output_idx); | ||||
| @@ -138,7 +138,6 @@ class KernelRuntime { | |||||
| bool LaunchKernelMod(const session::KernelGraph &graph); | bool LaunchKernelMod(const session::KernelGraph &graph); | ||||
| void LaunchKernelEvent(const std::vector<std::vector<std::function<void()>>> &run_events, size_t index); | void LaunchKernelEvent(const std::vector<std::vector<std::function<void()>>> &run_events, size_t index); | ||||
| static void GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs); | static void GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs); | ||||
| size_t CountNodeDeviceMemorySize(const AnfNodePtr &node, size_t output_index); | |||||
| void RunOpAssignInputMemory(const std::vector<tensor::TensorPtr> &input_tensors, const session::KernelGraph *graph); | void RunOpAssignInputMemory(const std::vector<tensor::TensorPtr> &input_tensors, const session::KernelGraph *graph); | ||||
| void RunOpAssignOutputMemory(const AnfNodePtr &kernel); | void RunOpAssignOutputMemory(const AnfNodePtr &kernel); | ||||
| void RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel); | void RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel); | ||||
| @@ -15,29 +15,200 @@ | |||||
| */ | */ | ||||
| #include "runtime/framework/graph_compiler.h" | #include "runtime/framework/graph_compiler.h" | ||||
| #include <numeric> | |||||
| #include <map> | |||||
| #include "runtime/framework/graph_scheduler.h" | #include "runtime/framework/graph_scheduler.h" | ||||
| #include "runtime/device/device_address.h" | |||||
| #include "common/trans.h" | |||||
| #include "utils/convert_utils.h" | |||||
| #include "ir/tensor.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace runtime { | namespace runtime { | ||||
| void GraphCompiler::set_device_context(device::DeviceContext *device_context) { | |||||
| namespace { | |||||
| // Whether device address of anf node is valid and device address type | |||||
| // is consistent with device type, for example, device address type | |||||
| // DeviceAddressType::kGPU should be used on GPU device | |||||
| bool NodeDeviceAddressExist(const DeviceContext *device_context, const AnfNodePtr &kernel, size_t index) { | |||||
| MS_EXCEPTION_IF_NULL(kernel); | |||||
| MS_EXCEPTION_IF_NULL(device_context); | |||||
| if (AnfAlgo::OutputAddrExist(kernel, index)) { | |||||
| const auto &address = AnfAlgo::GetOutputAddr(kernel, index); | |||||
| MS_EXCEPTION_IF_NULL(address); | |||||
| return address->DeviceType() == device_context->GetDeviceAddressType(); | |||||
| } | |||||
| return false; | |||||
| } | |||||
| void CreateParameterDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) { | |||||
| MS_EXCEPTION_IF_NULL(device_context); | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| std::vector<AnfNodePtr> graph_inputs = graph->inputs(); | |||||
| const std::vector<bool> &graph_valid_input = graph->valid_inputs(); | |||||
| graph_inputs.insert(graph_inputs.end(), graph->child_graph_result().begin(), graph->child_graph_result().end()); | |||||
| // Anf nodes which need create device address. | |||||
| std::vector<AnfNodePtr> nodes_list; | |||||
| for (size_t i = 0; i < graph_inputs.size(); ++i) { | |||||
| AnfNodePtr item = graph_inputs[i]; | |||||
| MS_EXCEPTION_IF_NULL(item); | |||||
| if (i < graph_valid_input.size() && !graph_valid_input[i]) { | |||||
| continue; | |||||
| } | |||||
| if (AnfAlgo::CheckPrimitiveType(item, prim::kPrimMakeTuple)) { | |||||
| std::vector<AnfNodePtr> outs = AnfAlgo::GetAllOutput(item); | |||||
| for (const auto &out : outs) { | |||||
| MS_EXCEPTION_IF_NULL(out); | |||||
| if (!out->isa<Parameter>() || NodeDeviceAddressExist(device_context, out, 0)) { | |||||
| continue; | |||||
| } | |||||
| nodes_list.push_back(out); | |||||
| } | |||||
| } | |||||
| if (!item->isa<Parameter>() || NodeDeviceAddressExist(device_context, item, 0)) { | |||||
| continue; | |||||
| } | |||||
| nodes_list.push_back(item); | |||||
| } | |||||
| // Create device address for anf node in nodes_list | |||||
| for (const auto &item : nodes_list) { | |||||
| auto output_size = AnfAlgo::GetOutputTensorNum(item); | |||||
| for (size_t index = 0; index < output_size; index++) { | |||||
| TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index); | |||||
| // if graph output is a weight and doesn't link to any cnode, it's data type will be unknown | |||||
| if (output_type_id == kTypeUnknown) { | |||||
| MS_LOG(WARNING) << "It is not suggested to use a lonely weight parameter as the output of graph"; | |||||
| continue; | |||||
| } | |||||
| size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index); | |||||
| auto device_address = device_context->CreateDeviceAddress(nullptr, tensor_size, | |||||
| AnfAlgo::GetOutputFormat(item, index), output_type_id); | |||||
| AnfAlgo::SetOutputAddr(device_address, index, item.get()); | |||||
| } | |||||
| } | |||||
| } | |||||
| void CreateDeviceAddressForTensorValue(const DeviceContext *device_context, const ValuePtr &node_value, | |||||
| size_t output_idx, const ValueNodePtr &value_node) { | |||||
| MS_EXCEPTION_IF_NULL(device_context); | |||||
| MS_EXCEPTION_IF_NULL(node_value); | |||||
| MS_EXCEPTION_IF_NULL(value_node); | |||||
| const auto &ms_context = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(ms_context); | |||||
| std::vector<tensor::TensorPtr> tensors; | |||||
| TensorValueToTensor(node_value, &tensors); | |||||
| for (const auto &tensor : tensors) { | |||||
| if (tensor == nullptr) { | |||||
| MS_LOG(WARNING) << "Tensor is null"; | |||||
| return; | |||||
| } | |||||
| auto output_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()); | |||||
| if (output_address != nullptr && output_address->DeviceType() == device_context->GetDeviceAddressType()) { | |||||
| AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++, | |||||
| value_node.get()); | |||||
| continue; | |||||
| } | |||||
| size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(value_node, output_idx); | |||||
| TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx); | |||||
| if (output_type_id == kTypeUnknown) { | |||||
| output_type_id = AnfAlgo::GetOutputInferDataType(value_node, output_idx); | |||||
| } | |||||
| std::string output_format = AnfAlgo::GetOutputFormat(value_node, output_idx); | |||||
| device::DeviceAddressPtr address = | |||||
| device_context->CreateDeviceAddress(nullptr, tensor_size, output_format, output_type_id); | |||||
| MS_EXCEPTION_IF_NULL(address); | |||||
| AnfAlgo::SetOutputAddr(address, output_idx, value_node.get()); | |||||
| } | |||||
| } | |||||
| void CreateValueNodeDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) { | |||||
| MS_EXCEPTION_IF_NULL(device_context); | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| for (const ValueNodePtr &value_node : graph->graph_value_nodes()) { | |||||
| MS_EXCEPTION_IF_NULL(value_node); | |||||
| if (NodeDeviceAddressExist(device_context, value_node, 0)) { | |||||
| continue; | |||||
| } | |||||
| const auto &node_value = value_node->value(); | |||||
| MS_EXCEPTION_IF_NULL(node_value); | |||||
| if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) { | |||||
| CreateDeviceAddressForTensorValue(device_context, node_value, 0, value_node); | |||||
| } else if (node_value->isa<StringImm>()) { | |||||
| auto value = GetValue<std::string>(node_value); | |||||
| size_t tensor_size = value.size(); | |||||
| auto address = device_context->CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8); | |||||
| MS_EXCEPTION_IF_NULL(address); | |||||
| AnfAlgo::SetOutputAddr(address, 0, value_node.get()); | |||||
| } | |||||
| } | |||||
| } | |||||
| void CreateKernelOutputDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) { | |||||
| MS_EXCEPTION_IF_NULL(device_context); | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| const std::vector<CNodePtr> &kernels = graph->execution_order(); | |||||
| for (const auto &kernel : kernels) { | |||||
| auto kernel_mod = AnfAlgo::GetKernelMod(kernel); | |||||
| MS_EXCEPTION_IF_NULL(kernel_mod); | |||||
| auto output_sizes = kernel_mod->GetOutputSizeList(); | |||||
| for (size_t i = 0; i < output_sizes.size(); ++i) { | |||||
| if (AnfAlgo::OutputAddrExist(kernel, i)) { | |||||
| continue; | |||||
| } | |||||
| std::string output_format = AnfAlgo::GetOutputFormat(kernel, i); | |||||
| auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i); | |||||
| auto device_address = device_context->CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type); | |||||
| AnfAlgo::SetOutputAddr(device_address, i, kernel.get()); | |||||
| } | |||||
| } | |||||
| } | |||||
| void CreateKernelWorkspaceDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) { | |||||
| MS_EXCEPTION_IF_NULL(device_context); | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| const std::vector<CNodePtr> &kernels = graph->execution_order(); | |||||
| for (const auto &kernel : kernels) { | |||||
| auto kernel_mod = AnfAlgo::GetKernelMod(kernel); | |||||
| MS_EXCEPTION_IF_NULL(kernel_mod); | |||||
| auto workspace_sizes = kernel_mod->GetWorkspaceSizeList(); | |||||
| for (size_t i = 0; i < workspace_sizes.size(); ++i) { | |||||
| auto device_address = device_context->CreateDeviceAddress(nullptr, workspace_sizes[i], "", kTypeUnknown); | |||||
| AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get()); | |||||
| } | |||||
| } | |||||
| } | |||||
| } // namespace | |||||
| void GraphCompiler::set_device_context(DeviceContext *device_context) { | |||||
| MS_EXCEPTION_IF_NULL(device_context); | MS_EXCEPTION_IF_NULL(device_context); | ||||
| device_context_ = device_context; | device_context_ = device_context; | ||||
| // The member variable 'session_' will be removed after removing session module. | // The member variable 'session_' will be removed after removing session module. | ||||
| if (session_ == nullptr) { | if (session_ == nullptr) { | ||||
| session_ = std::make_shared<session::SessionBasic>(); | session_ = std::make_shared<session::SessionBasic>(); | ||||
| const device::DeviceContextKey &device_context_key = device_context->device_context_key(); | |||||
| session_->InitExecutor(device_context_key.device_name_, device_context_key.device_id_); | |||||
| } | } | ||||
| } | } | ||||
| GraphId GraphCompiler::CompileGraph(const AnfNodePtrList &nodes, const AnfNodePtrList &outputs) { | GraphId GraphCompiler::CompileGraph(const AnfNodePtrList &nodes, const AnfNodePtrList &outputs) { | ||||
| MS_EXCEPTION_IF_NULL(session_); | MS_EXCEPTION_IF_NULL(session_); | ||||
| // Generate kernel graph. | // Generate kernel graph. | ||||
| auto graph = session_->ConstructKernelGraph(nodes, outputs); | |||||
| KernelGraphPtr graph = session_->ConstructKernelGraph(nodes, outputs); | |||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| return CompileGraphImpl(graph); | return CompileGraphImpl(graph); | ||||
| } | } | ||||
| GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph) { | |||||
| GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph) const { | |||||
| MS_EXCEPTION_IF_NULL(device_context_); | MS_EXCEPTION_IF_NULL(device_context_); | ||||
| // Optimization pass which is irrelevant to device type or format. | // Optimization pass which is irrelevant to device type or format. | ||||
| device_context_->OptimizeGraphWithoutDeviceInfo(graph); | device_context_->OptimizeGraphWithoutDeviceInfo(graph); | ||||
| @@ -51,6 +222,8 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph) { | |||||
| // 'KernelMod' is real executive object of kernel. | // 'KernelMod' is real executive object of kernel. | ||||
| device_context_->CreateKernel(graph->execution_order()); | device_context_->CreateKernel(graph->execution_order()); | ||||
| // Create device address for all anf nodes of graph. | |||||
| CreateDeviceAddress(graph); | |||||
| // Transform graph to actor DAG, contains build and link. | // Transform graph to actor DAG, contains build and link. | ||||
| GraphScheduler::GetInstance().Transform(graph, device_context_); | GraphScheduler::GetInstance().Transform(graph, device_context_); | ||||
| return graph->graph_id(); | return graph->graph_id(); | ||||
| @@ -68,7 +241,7 @@ GraphId GraphCompiler::CompileGraph(session::OpRunInfo *op_run_info, const Graph | |||||
| } | } | ||||
| // Generate kernel graph. | // Generate kernel graph. | ||||
| MS_EXCEPTION_IF_NULL(session_); | MS_EXCEPTION_IF_NULL(session_); | ||||
| auto graph = session_->ConstructSingleOpGraph(*op_run_info, *input_tensors, tensors_mask); | |||||
| KernelGraphPtr graph = session_->ConstructSingleOpGraph(*op_run_info, *input_tensors, tensors_mask); | |||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| MS_EXCEPTION_IF_NULL(device_context_); | MS_EXCEPTION_IF_NULL(device_context_); | ||||
| @@ -82,6 +255,8 @@ GraphId GraphCompiler::CompileGraph(session::OpRunInfo *op_run_info, const Graph | |||||
| // Generate 'KernelMod' for kernel in graph. | // Generate 'KernelMod' for kernel in graph. | ||||
| device_context_->CreateKernel(graph->execution_order()); | device_context_->CreateKernel(graph->execution_order()); | ||||
| // Create device address for all anf nodes of graph. | |||||
| CreateDeviceAddress(graph); | |||||
| // Transform graph to actor DAG, contains build and link. | // Transform graph to actor DAG, contains build and link. | ||||
| GraphScheduler::GetInstance().Transform(graph, device_context_, input_tensors, GraphExecutionStrategy::kStep); | GraphScheduler::GetInstance().Transform(graph, device_context_, input_tensors, GraphExecutionStrategy::kStep); | ||||
| run_op_graphs_[graph_info] = graph; | run_op_graphs_[graph_info] = graph; | ||||
| @@ -101,5 +276,12 @@ KernelGraphPtr GraphCompiler::Fetch(const GraphInfo &graph_info) const { | |||||
| } | } | ||||
| return iter->second; | return iter->second; | ||||
| } | } | ||||
| void GraphCompiler::CreateDeviceAddress(const KernelGraphPtr &graph) const { | |||||
| CreateParameterDeviceAddress(device_context_, graph); | |||||
| CreateValueNodeDeviceAddress(device_context_, graph); | |||||
| CreateKernelOutputDeviceAddress(device_context_, graph); | |||||
| CreateKernelWorkspaceDeviceAddress(device_context_, graph); | |||||
| } | |||||
| } // namespace runtime | } // namespace runtime | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -26,6 +26,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace runtime { | namespace runtime { | ||||
| using device::DeviceContext; | |||||
| class GraphCompiler { | class GraphCompiler { | ||||
| public: | public: | ||||
| static GraphCompiler &GetInstance() { | static GraphCompiler &GetInstance() { | ||||
| @@ -35,7 +36,7 @@ class GraphCompiler { | |||||
| // Set device context which is initialized, the function must be called | // Set device context which is initialized, the function must be called | ||||
| // before using GraphCompiler and after changing device type or device id. | // before using GraphCompiler and after changing device type or device id. | ||||
| void set_device_context(device::DeviceContext *device_context); | |||||
| void set_device_context(DeviceContext *device_context); | |||||
| // Construct kernel graph from anf nodes list and compile kernel graph in Graph mode, | // Construct kernel graph from anf nodes list and compile kernel graph in Graph mode, | ||||
| // the detailed implementation of compiling graph is in 'CompileGraphImpl'. | // the detailed implementation of compiling graph is in 'CompileGraphImpl'. | ||||
| @@ -58,9 +59,12 @@ class GraphCompiler { | |||||
| // The implementation of compiling graph in Graph Mode, including optimizing graph, | // The implementation of compiling graph in Graph Mode, including optimizing graph, | ||||
| // setting operator info, creating kernel and transforming kernel graph to ActorSet. | // setting operator info, creating kernel and transforming kernel graph to ActorSet. | ||||
| GraphId CompileGraphImpl(const KernelGraphPtr &graph); | |||||
| GraphId CompileGraphImpl(const KernelGraphPtr &graph) const; | |||||
| device::DeviceContext *device_context_{nullptr}; | |||||
| // Create device address for all anf nodes of graph. | |||||
| void CreateDeviceAddress(const KernelGraphPtr &graph) const; | |||||
| DeviceContext *device_context_{nullptr}; | |||||
| // Single op kernel graph cache for PyNative mode. | // Single op kernel graph cache for PyNative mode. | ||||
| std::unordered_map<GraphInfo, KernelGraphPtr> run_op_graphs_; | std::unordered_map<GraphInfo, KernelGraphPtr> run_op_graphs_; | ||||
| @@ -50,6 +50,11 @@ void CPUDeviceContext::FreeMemory(DeviceAddress *const &address) const { | |||||
| address->ptr_ = nullptr; | address->ptr_ = nullptr; | ||||
| } | } | ||||
| DeviceAddressPtr CPUDeviceContext::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, | |||||
| TypeId type_id) const { | |||||
| return std::make_shared<CPUDeviceAddress>(device_ptr, device_size, format, type_id); | |||||
| } | |||||
| void CPUDeviceContext::OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const { | void CPUDeviceContext::OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const { | ||||
| // Update Graph Dynamic Shape Attr. | // Update Graph Dynamic Shape Attr. | ||||
| UpdateGraphDynamicShapeAttr(NOT_NULL(graph)); | UpdateGraphDynamicShapeAttr(NOT_NULL(graph)); | ||||
| @@ -18,6 +18,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include <string> | |||||
| #include "runtime/hardware/device_context.h" | #include "runtime/hardware/device_context.h" | ||||
| #include "runtime/hardware/device_context_manager.h" | #include "runtime/hardware/device_context_manager.h" | ||||
| #include "runtime/device/memory_manager.h" | #include "runtime/device/memory_manager.h" | ||||
| @@ -36,6 +37,10 @@ class CPUDeviceContext : public DeviceContext { | |||||
| bool AllocateMemory(DeviceAddress *const &address, size_t size) const override; | bool AllocateMemory(DeviceAddress *const &address, size_t size) const override; | ||||
| void FreeMemory(DeviceAddress *const &address) const override; | void FreeMemory(DeviceAddress *const &address) const override; | ||||
| DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, | |||||
| TypeId type_id) const override; | |||||
| DeviceAddressType GetDeviceAddressType() const override { return DeviceAddressType::kCPU; } | |||||
| void OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const override; | void OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const override; | ||||
| void OptimizeSingleOpGraph(const KernelGraphPtr &graph) const override; | void OptimizeSingleOpGraph(const KernelGraphPtr &graph) const override; | ||||
| @@ -63,6 +63,13 @@ class DeviceContext { | |||||
| return true; | return true; | ||||
| } | } | ||||
| // Create a concrete device address according to the device type. | |||||
| virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, | |||||
| TypeId type_id) const = 0; | |||||
| // Get the device address type according to the device type, such as GPU or Ascend. | |||||
| virtual DeviceAddressType GetDeviceAddressType() const = 0; | |||||
| // The two functions below will be merged to one in the future. | // The two functions below will be merged to one in the future. | ||||
| // General graph optimizer ignores device data type and format. | // General graph optimizer ignores device data type and format. | ||||
| virtual void OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const {} | virtual void OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const {} | ||||
| @@ -90,6 +97,9 @@ class DeviceContext { | |||||
| // Devices that do not need stream could ignore the implementation of this function. | // Devices that do not need stream could ignore the implementation of this function. | ||||
| virtual bool SyncStream(size_t stream_id = 0) { return true; } | virtual bool SyncStream(size_t stream_id = 0) { return true; } | ||||
| // Get device_context_key_ to obtain device name and device id. | |||||
| const DeviceContextKey &device_context_key() const { return device_context_key_; } | |||||
| protected: | protected: | ||||
| DeviceContextKey device_context_key_; | DeviceContextKey device_context_key_; | ||||
| }; | }; | ||||
| @@ -165,6 +165,11 @@ bool GPUDeviceContext::AllocateContinuousMemory(const std::vector<DeviceAddress | |||||
| return true; | return true; | ||||
| } | } | ||||
| DeviceAddressPtr GPUDeviceContext::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, | |||||
| TypeId type_id) const { | |||||
| return std::make_shared<GPUDeviceAddress>(device_ptr, device_size, format, type_id); | |||||
| } | |||||
| void GPUDeviceContext::OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const { | void GPUDeviceContext::OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const { | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| // Operator fusion optimization. | // Operator fusion optimization. | ||||
| @@ -19,6 +19,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include <string> | |||||
| #include "runtime/hardware/device_context.h" | #include "runtime/hardware/device_context.h" | ||||
| #include "runtime/hardware/device_context_manager.h" | #include "runtime/hardware/device_context_manager.h" | ||||
| #include "runtime/device/memory_manager.h" | #include "runtime/device/memory_manager.h" | ||||
| @@ -43,6 +44,10 @@ class GPUDeviceContext : public DeviceContext { | |||||
| bool AllocateContinuousMemory(const std::vector<DeviceAddress *> &addr_list, size_t total_size, | bool AllocateContinuousMemory(const std::vector<DeviceAddress *> &addr_list, size_t total_size, | ||||
| const std::vector<size_t> &size_list) const override; | const std::vector<size_t> &size_list) const override; | ||||
| DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, | |||||
| TypeId type_id) const override; | |||||
| DeviceAddressType GetDeviceAddressType() const override { return DeviceAddressType::kGPU; } | |||||
| // General graph optimizer ignores device data type and format. | // General graph optimizer ignores device data type and format. | ||||
| void OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const override; | void OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const override; | ||||
| // Optimize the kernel graph according to device type, such as format transformation. | // Optimize the kernel graph according to device type, such as format transformation. | ||||