From: @zyli2020 Reviewed-by: Signed-off-by:pull/14610/MERGE
| @@ -19,6 +19,8 @@ | |||
| #include <map> | |||
| #include <set> | |||
| #include <unordered_set> | |||
| #include <functional> | |||
| #include <numeric> | |||
| #include "ir/anf.h" | |||
| #include "ir/func_graph.h" | |||
| #include "base/core_ops.h" | |||
| @@ -480,6 +482,28 @@ size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) { | |||
| return 1; | |||
| } | |||
| size_t AnfRuntimeAlgorithm::GetOutputTensorMemSize(const AnfNodePtr &node, size_t output_index) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (output_index >= AnfAlgo::GetOutputTensorNum(node)) { | |||
| MS_EXCEPTION(ArgumentError) << "output index [" << output_index << "] large than the output size [" | |||
| << AnfAlgo::GetOutputTensorNum(node) << "] of node!"; | |||
| } | |||
| TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index); | |||
| if (output_type_id == kTypeUnknown) { | |||
| output_type_id = AnfAlgo::GetOutputInferDataType(node, output_index); | |||
| } | |||
| size_t type_size = GetTypeByte(TypeIdToType(output_type_id)); | |||
| std::vector<size_t> shape = AnfAlgo::GetOutputDeviceShape(node, output_index); | |||
| auto format = AnfAlgo::GetOutputFormat(node, output_index); | |||
| if (shape.empty() && format != kOpFormat_DEFAULT) { | |||
| shape = trans::PaddingShape(shape, format, AnfAlgo::GetOutputReshapeType(node, output_index)); | |||
| shape = trans::TransShapeToDevice(shape, format); | |||
| } | |||
| // scalar's output shape is a empty vector | |||
| size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>()); | |||
| return tensor_size; | |||
| } | |||
| std::vector<std::string> AnfRuntimeAlgorithm::GetAllOutputFormats(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (!AnfAlgo::IsRealKernel(node)) { | |||
| @@ -105,6 +105,8 @@ class AnfRuntimeAlgorithm { | |||
| static size_t GetInputTensorNum(const AnfNodePtr &node); | |||
| // get the num of output real_kernel(which can be build and run in device) | |||
| static size_t GetOutputTensorNum(const AnfNodePtr &node); | |||
| // Get the memory size of output tensor of node. | |||
| static size_t GetOutputTensorMemSize(const AnfNodePtr &node, size_t output_index); | |||
| // get all outputs format select of anf node | |||
| static std::vector<std::string> GetAllOutputFormats(const AnfNodePtr &node); | |||
| // get all inputs format select of anf node | |||
| @@ -16,7 +16,6 @@ | |||
| #include "runtime/device/kernel_runtime.h" | |||
| #include <functional> | |||
| #include <numeric> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "backend/optimizer/common/helper.h" | |||
| @@ -57,28 +56,6 @@ bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_ | |||
| return false; | |||
| } | |||
| size_t KernelRuntime::CountNodeDeviceMemorySize(const mindspore::AnfNodePtr &node, size_t output_index) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (output_index >= AnfAlgo::GetOutputTensorNum(node)) { | |||
| MS_EXCEPTION(ArgumentError) << "output index [" << output_index << "] large than the output size [" | |||
| << AnfAlgo::GetOutputTensorNum(node) << "] of node!"; | |||
| } | |||
| TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index); | |||
| if (output_type_id == kTypeUnknown) { | |||
| output_type_id = AnfAlgo::GetOutputInferDataType(node, output_index); | |||
| } | |||
| size_t type_size = GetTypeByte(TypeIdToType(output_type_id)); | |||
| std::vector<size_t> shape = AnfAlgo::GetOutputDeviceShape(node, output_index); | |||
| auto format = AnfAlgo::GetOutputFormat(node, output_index); | |||
| if (shape.empty() && format != kOpFormat_DEFAULT) { | |||
| shape = trans::PaddingShape(shape, format, AnfAlgo::GetOutputReshapeType(node, output_index)); | |||
| shape = trans::TransShapeToDevice(shape, format); | |||
| } | |||
| // scalar's output shape is a empty vector | |||
| size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>()); | |||
| return tensor_size; | |||
| } | |||
| void KernelRuntime::AssignMemory(session::KernelGraph *graph) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| @@ -184,7 +161,7 @@ void KernelRuntime::RunOpAssignInputMemory(const std::vector<tensor::TensorPtr> | |||
| if (output_type_id == kTypeUnknown) { | |||
| output_type_id = AnfAlgo::GetOutputInferDataType(item, index); | |||
| } | |||
| auto tensor_size = CountNodeDeviceMemorySize(item, index); | |||
| auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index); | |||
| auto device_address = | |||
| CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id); | |||
| MS_EXCEPTION_IF_NULL(device_address); | |||
| @@ -361,7 +338,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { | |||
| continue; | |||
| } | |||
| #endif | |||
| auto tensor_size = CountNodeDeviceMemorySize(item, index); | |||
| auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index); | |||
| device_address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id); | |||
| MS_LOG(DEBUG) << "Malloc static memory for " << item->fullname_with_scope(); | |||
| if (mem_manager_->MallocMem(kStaticMem, tensor_size, device_address, graph->graph_id()) == nullptr) { | |||
| @@ -656,7 +633,7 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const | |||
| continue; | |||
| } | |||
| size_t tensor_size = tensor->data().nbytes(); | |||
| auto node_size = CountNodeDeviceMemorySize(value_node, output_idx); | |||
| auto node_size = AnfAlgo::GetOutputTensorMemSize(value_node, output_idx); | |||
| TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx); | |||
| if (output_type_id == kTypeUnknown) { | |||
| output_type_id = AnfAlgo::GetOutputInferDataType(value_node, output_idx); | |||
| @@ -138,7 +138,6 @@ class KernelRuntime { | |||
| bool LaunchKernelMod(const session::KernelGraph &graph); | |||
| void LaunchKernelEvent(const std::vector<std::vector<std::function<void()>>> &run_events, size_t index); | |||
| static void GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs); | |||
| size_t CountNodeDeviceMemorySize(const AnfNodePtr &node, size_t output_index); | |||
| void RunOpAssignInputMemory(const std::vector<tensor::TensorPtr> &input_tensors, const session::KernelGraph *graph); | |||
| void RunOpAssignOutputMemory(const AnfNodePtr &kernel); | |||
| void RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel); | |||
| @@ -15,29 +15,200 @@ | |||
| */ | |||
| #include "runtime/framework/graph_compiler.h" | |||
| #include <numeric> | |||
| #include <map> | |||
| #include "runtime/framework/graph_scheduler.h" | |||
| #include "runtime/device/device_address.h" | |||
| #include "common/trans.h" | |||
| #include "utils/convert_utils.h" | |||
| #include "ir/tensor.h" | |||
| namespace mindspore { | |||
| namespace runtime { | |||
| void GraphCompiler::set_device_context(device::DeviceContext *device_context) { | |||
| namespace { | |||
| // Whether device address of anf node is valid and device address type | |||
| // is consistent with device type, for example, device address type | |||
| // DeviceAddressType::kGPU should be used on GPU device | |||
| bool NodeDeviceAddressExist(const DeviceContext *device_context, const AnfNodePtr &kernel, size_t index) { | |||
| MS_EXCEPTION_IF_NULL(kernel); | |||
| MS_EXCEPTION_IF_NULL(device_context); | |||
| if (AnfAlgo::OutputAddrExist(kernel, index)) { | |||
| const auto &address = AnfAlgo::GetOutputAddr(kernel, index); | |||
| MS_EXCEPTION_IF_NULL(address); | |||
| return address->DeviceType() == device_context->GetDeviceAddressType(); | |||
| } | |||
| return false; | |||
| } | |||
| void CreateParameterDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) { | |||
| MS_EXCEPTION_IF_NULL(device_context); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| std::vector<AnfNodePtr> graph_inputs = graph->inputs(); | |||
| const std::vector<bool> &graph_valid_input = graph->valid_inputs(); | |||
| graph_inputs.insert(graph_inputs.end(), graph->child_graph_result().begin(), graph->child_graph_result().end()); | |||
| // Anf nodes which need create device address. | |||
| std::vector<AnfNodePtr> nodes_list; | |||
| for (size_t i = 0; i < graph_inputs.size(); ++i) { | |||
| AnfNodePtr item = graph_inputs[i]; | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| if (i < graph_valid_input.size() && !graph_valid_input[i]) { | |||
| continue; | |||
| } | |||
| if (AnfAlgo::CheckPrimitiveType(item, prim::kPrimMakeTuple)) { | |||
| std::vector<AnfNodePtr> outs = AnfAlgo::GetAllOutput(item); | |||
| for (const auto &out : outs) { | |||
| MS_EXCEPTION_IF_NULL(out); | |||
| if (!out->isa<Parameter>() || NodeDeviceAddressExist(device_context, out, 0)) { | |||
| continue; | |||
| } | |||
| nodes_list.push_back(out); | |||
| } | |||
| } | |||
| if (!item->isa<Parameter>() || NodeDeviceAddressExist(device_context, item, 0)) { | |||
| continue; | |||
| } | |||
| nodes_list.push_back(item); | |||
| } | |||
| // Create device address for anf node in nodes_list | |||
| for (const auto &item : nodes_list) { | |||
| auto output_size = AnfAlgo::GetOutputTensorNum(item); | |||
| for (size_t index = 0; index < output_size; index++) { | |||
| TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index); | |||
| // if graph output is a weight and doesn't link to any cnode, it's data type will be unknown | |||
| if (output_type_id == kTypeUnknown) { | |||
| MS_LOG(WARNING) << "It is not suggested to use a lonely weight parameter as the output of graph"; | |||
| continue; | |||
| } | |||
| size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index); | |||
| auto device_address = device_context->CreateDeviceAddress(nullptr, tensor_size, | |||
| AnfAlgo::GetOutputFormat(item, index), output_type_id); | |||
| AnfAlgo::SetOutputAddr(device_address, index, item.get()); | |||
| } | |||
| } | |||
| } | |||
| void CreateDeviceAddressForTensorValue(const DeviceContext *device_context, const ValuePtr &node_value, | |||
| size_t output_idx, const ValueNodePtr &value_node) { | |||
| MS_EXCEPTION_IF_NULL(device_context); | |||
| MS_EXCEPTION_IF_NULL(node_value); | |||
| MS_EXCEPTION_IF_NULL(value_node); | |||
| const auto &ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| std::vector<tensor::TensorPtr> tensors; | |||
| TensorValueToTensor(node_value, &tensors); | |||
| for (const auto &tensor : tensors) { | |||
| if (tensor == nullptr) { | |||
| MS_LOG(WARNING) << "Tensor is null"; | |||
| return; | |||
| } | |||
| auto output_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()); | |||
| if (output_address != nullptr && output_address->DeviceType() == device_context->GetDeviceAddressType()) { | |||
| AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++, | |||
| value_node.get()); | |||
| continue; | |||
| } | |||
| size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(value_node, output_idx); | |||
| TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx); | |||
| if (output_type_id == kTypeUnknown) { | |||
| output_type_id = AnfAlgo::GetOutputInferDataType(value_node, output_idx); | |||
| } | |||
| std::string output_format = AnfAlgo::GetOutputFormat(value_node, output_idx); | |||
| device::DeviceAddressPtr address = | |||
| device_context->CreateDeviceAddress(nullptr, tensor_size, output_format, output_type_id); | |||
| MS_EXCEPTION_IF_NULL(address); | |||
| AnfAlgo::SetOutputAddr(address, output_idx, value_node.get()); | |||
| } | |||
| } | |||
| void CreateValueNodeDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) { | |||
| MS_EXCEPTION_IF_NULL(device_context); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| for (const ValueNodePtr &value_node : graph->graph_value_nodes()) { | |||
| MS_EXCEPTION_IF_NULL(value_node); | |||
| if (NodeDeviceAddressExist(device_context, value_node, 0)) { | |||
| continue; | |||
| } | |||
| const auto &node_value = value_node->value(); | |||
| MS_EXCEPTION_IF_NULL(node_value); | |||
| if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) { | |||
| CreateDeviceAddressForTensorValue(device_context, node_value, 0, value_node); | |||
| } else if (node_value->isa<StringImm>()) { | |||
| auto value = GetValue<std::string>(node_value); | |||
| size_t tensor_size = value.size(); | |||
| auto address = device_context->CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8); | |||
| MS_EXCEPTION_IF_NULL(address); | |||
| AnfAlgo::SetOutputAddr(address, 0, value_node.get()); | |||
| } | |||
| } | |||
| } | |||
| void CreateKernelOutputDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) { | |||
| MS_EXCEPTION_IF_NULL(device_context); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| const std::vector<CNodePtr> &kernels = graph->execution_order(); | |||
| for (const auto &kernel : kernels) { | |||
| auto kernel_mod = AnfAlgo::GetKernelMod(kernel); | |||
| MS_EXCEPTION_IF_NULL(kernel_mod); | |||
| auto output_sizes = kernel_mod->GetOutputSizeList(); | |||
| for (size_t i = 0; i < output_sizes.size(); ++i) { | |||
| if (AnfAlgo::OutputAddrExist(kernel, i)) { | |||
| continue; | |||
| } | |||
| std::string output_format = AnfAlgo::GetOutputFormat(kernel, i); | |||
| auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i); | |||
| auto device_address = device_context->CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type); | |||
| AnfAlgo::SetOutputAddr(device_address, i, kernel.get()); | |||
| } | |||
| } | |||
| } | |||
| void CreateKernelWorkspaceDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) { | |||
| MS_EXCEPTION_IF_NULL(device_context); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| const std::vector<CNodePtr> &kernels = graph->execution_order(); | |||
| for (const auto &kernel : kernels) { | |||
| auto kernel_mod = AnfAlgo::GetKernelMod(kernel); | |||
| MS_EXCEPTION_IF_NULL(kernel_mod); | |||
| auto workspace_sizes = kernel_mod->GetWorkspaceSizeList(); | |||
| for (size_t i = 0; i < workspace_sizes.size(); ++i) { | |||
| auto device_address = device_context->CreateDeviceAddress(nullptr, workspace_sizes[i], "", kTypeUnknown); | |||
| AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get()); | |||
| } | |||
| } | |||
| } | |||
| } // namespace | |||
| void GraphCompiler::set_device_context(DeviceContext *device_context) { | |||
| MS_EXCEPTION_IF_NULL(device_context); | |||
| device_context_ = device_context; | |||
| // The member variable 'session_' will be removed after removing session module. | |||
| if (session_ == nullptr) { | |||
| session_ = std::make_shared<session::SessionBasic>(); | |||
| const device::DeviceContextKey &device_context_key = device_context->device_context_key(); | |||
| session_->InitExecutor(device_context_key.device_name_, device_context_key.device_id_); | |||
| } | |||
| } | |||
| GraphId GraphCompiler::CompileGraph(const AnfNodePtrList &nodes, const AnfNodePtrList &outputs) { | |||
| MS_EXCEPTION_IF_NULL(session_); | |||
| // Generate kernel graph. | |||
| auto graph = session_->ConstructKernelGraph(nodes, outputs); | |||
| KernelGraphPtr graph = session_->ConstructKernelGraph(nodes, outputs); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| return CompileGraphImpl(graph); | |||
| } | |||
| GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph) { | |||
| GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph) const { | |||
| MS_EXCEPTION_IF_NULL(device_context_); | |||
| // Optimization pass which is irrelevant to device type or format. | |||
| device_context_->OptimizeGraphWithoutDeviceInfo(graph); | |||
| @@ -51,6 +222,8 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph) { | |||
| // 'KernelMod' is real executive object of kernel. | |||
| device_context_->CreateKernel(graph->execution_order()); | |||
| // Create device address for all anf nodes of graph. | |||
| CreateDeviceAddress(graph); | |||
| // Transform graph to actor DAG, contains build and link. | |||
| GraphScheduler::GetInstance().Transform(graph, device_context_); | |||
| return graph->graph_id(); | |||
| @@ -68,7 +241,7 @@ GraphId GraphCompiler::CompileGraph(session::OpRunInfo *op_run_info, const Graph | |||
| } | |||
| // Generate kernel graph. | |||
| MS_EXCEPTION_IF_NULL(session_); | |||
| auto graph = session_->ConstructSingleOpGraph(*op_run_info, *input_tensors, tensors_mask); | |||
| KernelGraphPtr graph = session_->ConstructSingleOpGraph(*op_run_info, *input_tensors, tensors_mask); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(device_context_); | |||
| @@ -82,6 +255,8 @@ GraphId GraphCompiler::CompileGraph(session::OpRunInfo *op_run_info, const Graph | |||
| // Generate 'KernelMod' for kernel in graph. | |||
| device_context_->CreateKernel(graph->execution_order()); | |||
| // Create device address for all anf nodes of graph. | |||
| CreateDeviceAddress(graph); | |||
| // Transform graph to actor DAG, contains build and link. | |||
| GraphScheduler::GetInstance().Transform(graph, device_context_, input_tensors, GraphExecutionStrategy::kStep); | |||
| run_op_graphs_[graph_info] = graph; | |||
| @@ -101,5 +276,12 @@ KernelGraphPtr GraphCompiler::Fetch(const GraphInfo &graph_info) const { | |||
| } | |||
| return iter->second; | |||
| } | |||
| void GraphCompiler::CreateDeviceAddress(const KernelGraphPtr &graph) const { | |||
| CreateParameterDeviceAddress(device_context_, graph); | |||
| CreateValueNodeDeviceAddress(device_context_, graph); | |||
| CreateKernelOutputDeviceAddress(device_context_, graph); | |||
| CreateKernelWorkspaceDeviceAddress(device_context_, graph); | |||
| } | |||
| } // namespace runtime | |||
| } // namespace mindspore | |||
| @@ -26,6 +26,7 @@ | |||
| namespace mindspore { | |||
| namespace runtime { | |||
| using device::DeviceContext; | |||
| class GraphCompiler { | |||
| public: | |||
| static GraphCompiler &GetInstance() { | |||
| @@ -35,7 +36,7 @@ class GraphCompiler { | |||
| // Set device context which is initialized, the function must be called | |||
| // before using GraphCompiler and after changing device type or device id. | |||
| void set_device_context(device::DeviceContext *device_context); | |||
| void set_device_context(DeviceContext *device_context); | |||
| // Construct kernel graph from anf nodes list and compile kernel graph in Graph mode, | |||
| // the detailed implementation of compiling graph is in 'CompileGraphImpl'. | |||
| @@ -58,9 +59,12 @@ class GraphCompiler { | |||
| // The implementation of compiling graph in Graph Mode, including optimizing graph, | |||
| // setting operator info, creating kernel and transforming kernel graph to ActorSet. | |||
| GraphId CompileGraphImpl(const KernelGraphPtr &graph); | |||
| GraphId CompileGraphImpl(const KernelGraphPtr &graph) const; | |||
| device::DeviceContext *device_context_{nullptr}; | |||
| // Create device address for all anf nodes of graph. | |||
| void CreateDeviceAddress(const KernelGraphPtr &graph) const; | |||
| DeviceContext *device_context_{nullptr}; | |||
| // Single op kernel graph cache for PyNative mode. | |||
| std::unordered_map<GraphInfo, KernelGraphPtr> run_op_graphs_; | |||
| @@ -50,6 +50,11 @@ void CPUDeviceContext::FreeMemory(DeviceAddress *const &address) const { | |||
| address->ptr_ = nullptr; | |||
| } | |||
| DeviceAddressPtr CPUDeviceContext::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, | |||
| TypeId type_id) const { | |||
| return std::make_shared<CPUDeviceAddress>(device_ptr, device_size, format, type_id); | |||
| } | |||
| void CPUDeviceContext::OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const { | |||
| // Update Graph Dynamic Shape Attr. | |||
| UpdateGraphDynamicShapeAttr(NOT_NULL(graph)); | |||
| @@ -18,6 +18,7 @@ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "runtime/hardware/device_context.h" | |||
| #include "runtime/hardware/device_context_manager.h" | |||
| #include "runtime/device/memory_manager.h" | |||
| @@ -36,6 +37,10 @@ class CPUDeviceContext : public DeviceContext { | |||
| bool AllocateMemory(DeviceAddress *const &address, size_t size) const override; | |||
| void FreeMemory(DeviceAddress *const &address) const override; | |||
| DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, | |||
| TypeId type_id) const override; | |||
| DeviceAddressType GetDeviceAddressType() const override { return DeviceAddressType::kCPU; } | |||
| void OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const override; | |||
| void OptimizeSingleOpGraph(const KernelGraphPtr &graph) const override; | |||
| @@ -63,6 +63,13 @@ class DeviceContext { | |||
| return true; | |||
| } | |||
| // Create a concrete device address according to the device type. | |||
| virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, | |||
| TypeId type_id) const = 0; | |||
| // Get the device address type according to the device type, such as GPU or Ascend. | |||
| virtual DeviceAddressType GetDeviceAddressType() const = 0; | |||
| // The two functions below will be merged to one in the future. | |||
| // General graph optimizer ignores device data type and format. | |||
| virtual void OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const {} | |||
| @@ -90,6 +97,9 @@ class DeviceContext { | |||
| // Devices that do not need stream could ignore the implementation of this function. | |||
| virtual bool SyncStream(size_t stream_id = 0) { return true; } | |||
| // Get device_context_key_ to obtain device name and device id. | |||
| const DeviceContextKey &device_context_key() const { return device_context_key_; } | |||
| protected: | |||
| DeviceContextKey device_context_key_; | |||
| }; | |||
| @@ -165,6 +165,11 @@ bool GPUDeviceContext::AllocateContinuousMemory(const std::vector<DeviceAddress | |||
| return true; | |||
| } | |||
| DeviceAddressPtr GPUDeviceContext::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, | |||
| TypeId type_id) const { | |||
| return std::make_shared<GPUDeviceAddress>(device_ptr, device_size, format, type_id); | |||
| } | |||
| void GPUDeviceContext::OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| // Operator fusion optimization. | |||
| @@ -19,6 +19,7 @@ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "runtime/hardware/device_context.h" | |||
| #include "runtime/hardware/device_context_manager.h" | |||
| #include "runtime/device/memory_manager.h" | |||
| @@ -43,6 +44,10 @@ class GPUDeviceContext : public DeviceContext { | |||
| bool AllocateContinuousMemory(const std::vector<DeviceAddress *> &addr_list, size_t total_size, | |||
| const std::vector<size_t> &size_list) const override; | |||
| DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, | |||
| TypeId type_id) const override; | |||
| DeviceAddressType GetDeviceAddressType() const override { return DeviceAddressType::kGPU; } | |||
| // General graph optimizer ignores device data type and format. | |||
| void OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const override; | |||
| // Optimize the kernel graph according to device type, such format transform. | |||