| @@ -69,11 +69,16 @@ class DeviceAddress : public mindspore::DeviceSync { | |||||
| virtual DeviceAddressStatus status() const { return DeviceAddressStatus::kInDevice; } | virtual DeviceAddressStatus status() const { return DeviceAddressStatus::kInDevice; } | ||||
| virtual DeviceAddressType DeviceType() const { return DeviceAddressType::kUnknown; } | virtual DeviceAddressType DeviceType() const { return DeviceAddressType::kUnknown; } | ||||
| void *GetMutablePtr() const override { return ptr_; } | void *GetMutablePtr() const override { return ptr_; } | ||||
| // The related interface of reference count operation. | |||||
| void set_original_ref_count(size_t original_ref_count) { original_ref_count_ = original_ref_count; } | |||||
| size_t original_ref_count() const { return original_ref_count_; } | |||||
| void set_ref_count(size_t ref_count) { ref_count_ = ref_count; } | void set_ref_count(size_t ref_count) { ref_count_ = ref_count; } | ||||
| void IncreaseRefCount() { ref_count_++; } | |||||
| void DecreaseRefCountUsed() { ref_count_dynamic_used_--; } | |||||
| void ResetRefCountUsed() { ref_count_dynamic_used_ = ref_count_; } | |||||
| size_t ref_count_dynamic_used() const { return ref_count_dynamic_used_; } | |||||
| size_t ref_count() const { return ref_count_; } | |||||
| void IncreaseOriginalRefCount() { original_ref_count_++; } | |||||
| void DecreaseRefCount() { ref_count_--; } | |||||
| void ResetRefCount() { ref_count_ = original_ref_count_; } | |||||
| virtual bool DumpMemToFile(const std::string &filepath, const std::string &host_fmt, const ShapeVector &host_shape, | virtual bool DumpMemToFile(const std::string &filepath, const std::string &host_fmt, const ShapeVector &host_shape, | ||||
| TypeId host_type, bool trans_flag) const { | TypeId host_type, bool trans_flag) const { | ||||
| return true; | return true; | ||||
| @@ -91,9 +96,9 @@ class DeviceAddress : public mindspore::DeviceSync { | |||||
| void set_ptr(void *ptr) { ptr_ = ptr; } | void set_ptr(void *ptr) { ptr_ = ptr; } | ||||
| void *ptr_{nullptr}; | void *ptr_{nullptr}; | ||||
| size_t size_{0}; | size_t size_{0}; | ||||
| size_t original_ref_count_{1}; | |||||
| // It will be decreased in the running, and reset by original_ref_count_ when it is zero. | |||||
| size_t ref_count_{1}; | size_t ref_count_{1}; | ||||
| // It will be decreased in the running, and reset by ref_count_ when it is zero. | |||||
| size_t ref_count_dynamic_used_{1}; | |||||
| string format_{"DefaultFormat"}; | string format_{"DefaultFormat"}; | ||||
| TypeId type_id_{kNumberTypeFloat16}; | TypeId type_id_{kNumberTypeFloat16}; | ||||
| bool from_mem_pool_{false}; | bool from_mem_pool_{false}; | ||||
| @@ -36,7 +36,7 @@ using mindspore::device::DeviceContext; | |||||
| // The data source actor is used to fetch data from data source and process them into device tensors, | // The data source actor is used to fetch data from data source and process them into device tensors, | ||||
| // and then send them to kernel actor. The processing flow is FetchData -> FillDataBuffer -> AllocateMemory | // and then send them to kernel actor. The processing flow is FetchData -> FillDataBuffer -> AllocateMemory | ||||
| // -> OnMemoryAllocFinish -> SendOutput -> FreeMemory. | |||||
| // -> OnMemoryAllocFinish -> FreeMemory -> SendOutput. | |||||
| class DataSourceActor : public MemoryInterfaceActor { | class DataSourceActor : public MemoryInterfaceActor { | ||||
| public: | public: | ||||
| DataSourceActor(std::string name, size_t buffer_capacity, const DeviceContext *device_context, | DataSourceActor(std::string name, size_t buffer_capacity, const DeviceContext *device_context, | ||||
| @@ -37,7 +37,7 @@ using mindspore::kernel::AddressPtr; | |||||
| // The kernel actor is used to receive the device tensors and control info to luanch kernel. | // The kernel actor is used to receive the device tensors and control info to luanch kernel. | ||||
| // The processing flow is RunOpData/RunOpControl -> CheckLaunchCondition -> AllocateMemory | // The processing flow is RunOpData/RunOpControl -> CheckLaunchCondition -> AllocateMemory | ||||
| // -> OnMemoryAllocFinish -> LaunchKernel -> SendOutput -> FreeMemory. | |||||
| // -> OnMemoryAllocFinish -> LaunchKernel -> FreeMemory -> SendOutput. | |||||
| class KernelActor : public MemoryInterfaceActor { | class KernelActor : public MemoryInterfaceActor { | ||||
| public: | public: | ||||
| KernelActor(std::string name, CNodePtr kernel, const DeviceContext *device_context, const AID memory_manager_aid) | KernelActor(std::string name, CNodePtr kernel, const DeviceContext *device_context, const AID memory_manager_aid) | ||||
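
The two flow comments above describe the same reordering: memory is released before the outputs are forwarded downstream. As a rough orientation only, the step order reads like the following toy sketch; the types and function bodies here are illustrative stand-ins, not the real MindSpore actor API.

    #include <iostream>

    // Illustrative-only sketch of the kernel actor step order described in the
    // updated comment; none of these names are the real runtime implementation.
    struct ToyActorStep {
      bool CheckLaunchCondition() { return true; }  // all input data and control arrows arrived
      void AllocateMemory() { std::cout << "alloc\n"; }
      void LaunchKernel() { std::cout << "launch\n"; }
      void OnMemoryAllocFinish() { LaunchKernel(); }
      void FreeMemory() { std::cout << "free (reference counts decrease)\n"; }
      void SendOutput() { std::cout << "send outputs downstream\n"; }

      void Run() {
        if (!CheckLaunchCondition()) return;
        AllocateMemory();
        OnMemoryAllocFinish();
        // Per the updated comment, memory is released (reference counts decremented)
        // before the outputs are forwarded to downstream actors.
        FreeMemory();
        SendOutput();
      }
    };

    int main() {
      ToyActorStep step;
      step.Run();
      return 0;
    }
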
| @@ -50,13 +50,13 @@ void MemoryManagerActor::FreeMemory(std::vector<DeviceTensor *> free_list, const | |||||
| for (auto &device_tensor : free_list) { | for (auto &device_tensor : free_list) { | ||||
| MS_EXCEPTION_IF_NULL(device_tensor); | MS_EXCEPTION_IF_NULL(device_tensor); | ||||
| // The reference count is decremented to zero to free memory, and reset to the original count. | // The reference count is decremented to zero to free memory, and reset to the original count. | ||||
| device_tensor->DecreaseRefCountUsed(); | |||||
| if (device_tensor->ref_count_dynamic_used() == 0) { | |||||
| device_tensor->DecreaseRefCount(); | |||||
| if (device_tensor->ref_count() == 0) { | |||||
| // Free memory through the device context. | // Free memory through the device context. | ||||
| if (device_tensor->GetPtr() != nullptr) { | if (device_tensor->GetPtr() != nullptr) { | ||||
| device_context->FreeMemory(device_tensor); | device_context->FreeMemory(device_tensor); | ||||
| } | } | ||||
| device_tensor->ResetRefCountUsed(); | |||||
| device_tensor->ResetRefCount(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -78,10 +78,8 @@ void CreateParameterDeviceAddress(const DeviceContext *device_context, const Ker | |||||
| auto output_size = AnfAlgo::GetOutputTensorNum(item); | auto output_size = AnfAlgo::GetOutputTensorNum(item); | ||||
| for (size_t index = 0; index < output_size; index++) { | for (size_t index = 0; index < output_size; index++) { | ||||
| TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index); | TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index); | ||||
| // if graph output is a weight and doesn't link to any cnode, it's data type will be unknown | |||||
| if (output_type_id == kTypeUnknown) { | if (output_type_id == kTypeUnknown) { | ||||
| MS_LOG(WARNING) << "It is not suggested to use a lonely weight parameter as the output of graph"; | |||||
| continue; | |||||
| output_type_id = AnfAlgo::GetOutputInferDataType(item, index); | |||||
| } | } | ||||
| size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index); | size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index); | ||||
| @@ -212,13 +210,9 @@ GraphId GraphCompiler::CompileGraph(const AnfNodePtrList &nodes, const AnfNodePt | |||||
| GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph) const { | GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph) const { | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| MS_EXCEPTION_IF_NULL(device_context_); | MS_EXCEPTION_IF_NULL(device_context_); | ||||
| // Optimization pass which is irrelevant to device type or format. | |||||
| device_context_->OptimizeGraphWithoutDeviceInfo(graph); | |||||
| device_context_->SetOperatorInfo(graph->execution_order()); | |||||
| // Optimization pass which is relevant to device type or format. | |||||
| device_context_->OptimizeGraphWithDeviceInfo(graph); | |||||
| // Execute optimization pass. | |||||
| device_context_->OptimizeGraph(graph); | |||||
| // Generate 'KernelMod' for all kernels and set 'KernelMod' into kernel, | // Generate 'KernelMod' for all kernels and set 'KernelMod' into kernel, | ||||
| // 'KernelMod' is real executive object of kernel. | // 'KernelMod' is real executive object of kernel. | ||||
| @@ -248,9 +242,8 @@ GraphId GraphCompiler::CompileGraph(session::OpRunInfo *op_run_info, const Graph | |||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| MS_EXCEPTION_IF_NULL(device_context_); | MS_EXCEPTION_IF_NULL(device_context_); | ||||
| device_context_->SetOperatorInfo(graph->execution_order()); | |||||
| device_context_->OptimizeSingleOpGraph(graph); | device_context_->OptimizeSingleOpGraph(graph); | ||||
| MS_EXCEPTION_IF_NULL(session_); | MS_EXCEPTION_IF_NULL(session_); | ||||
| session_->RunOpHideNopNode(graph); | session_->RunOpHideNopNode(graph); | ||||
| session_->RunOpRemoveNopNode(graph); | session_->RunOpRemoveNopNode(graph); | ||||
| @@ -99,8 +99,8 @@ void UpdateRefCount(const AnfNodePtr &node, size_t output_idx) { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| auto device_tensor = AnfAlgo::GetMutableOutputAddr(node, output_idx); | auto device_tensor = AnfAlgo::GetMutableOutputAddr(node, output_idx); | ||||
| MS_EXCEPTION_IF_NULL(device_tensor); | MS_EXCEPTION_IF_NULL(device_tensor); | ||||
| device_tensor->IncreaseRefCount(); | |||||
| device_tensor->ResetRefCountUsed(); | |||||
| device_tensor->IncreaseOriginalRefCount(); | |||||
| device_tensor->ResetRefCount(); | |||||
| } | } | ||||
| // The branch processing of PrepareDataForValueNode that value type is tensor. | // The branch processing of PrepareDataForValueNode that value type is tensor. | ||||
| @@ -252,8 +252,8 @@ BaseRef CreateOutputTensor(const session::KernelWithIndex &node_output_pair, con | |||||
| const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, output_index); | const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, output_index); | ||||
| MS_EXCEPTION_IF_NULL(device_tensor); | MS_EXCEPTION_IF_NULL(device_tensor); | ||||
| tensor->set_device_address(device_tensor); | tensor->set_device_address(device_tensor); | ||||
| device_tensor->set_ref_count(SIZE_MAX); | |||||
| device_tensor->ResetRefCountUsed(); | |||||
| device_tensor->set_original_ref_count(SIZE_MAX); | |||||
| device_tensor->ResetRefCount(); | |||||
| return tensor; | return tensor; | ||||
| } | } | ||||
| } | } | ||||
| @@ -307,8 +307,8 @@ void AllocateContinuousMemoryForInput(const AnfNodePtr &kernel, const DeviceCont | |||||
| MS_EXCEPTION_IF_NULL(device_tensor); | MS_EXCEPTION_IF_NULL(device_tensor); | ||||
| // In the scene of communication op and computing op parallel multi stream, the input address of communication op | // In the scene of communication op and computing op parallel multi stream, the input address of communication op | ||||
| // can't be reused, so set the max reference count. | // can't be reused, so set the max reference count. | ||||
| device_tensor->set_ref_count(SIZE_MAX); | |||||
| device_tensor->ResetRefCountUsed(); | |||||
| device_tensor->set_original_ref_count(SIZE_MAX); | |||||
| device_tensor->ResetRefCount(); | |||||
| if (device_tensor->GetPtr() == nullptr) { | if (device_tensor->GetPtr() == nullptr) { | ||||
| is_need_alloc_memory = true; | is_need_alloc_memory = true; | ||||
| @@ -341,8 +341,8 @@ void AllocateContinuousMemoryForOutput(const AnfNodePtr &kernel, const DeviceCon | |||||
| const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(kernel, i, false); | const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(kernel, i, false); | ||||
| MS_EXCEPTION_IF_NULL(device_tensor); | MS_EXCEPTION_IF_NULL(device_tensor); | ||||
| // One time application for continuous memory, so set the max reference count. | // One time application for continuous memory, so set the max reference count. | ||||
| device_tensor->set_ref_count(SIZE_MAX); | |||||
| device_tensor->ResetRefCountUsed(); | |||||
| device_tensor->set_original_ref_count(SIZE_MAX); | |||||
| device_tensor->ResetRefCount(); | |||||
| if (device_tensor->GetPtr() == nullptr) { | if (device_tensor->GetPtr() == nullptr) { | ||||
| is_need_alloc_memory = true; | is_need_alloc_memory = true; | ||||
| @@ -925,8 +925,8 @@ void GraphScheduler::PersistDeviceTensor(const KernelGraphPtr &graph) { | |||||
| } | } | ||||
| auto device_tensor = AnfAlgo::GetMutableOutputAddr(value_node, 0); | auto device_tensor = AnfAlgo::GetMutableOutputAddr(value_node, 0); | ||||
| DeviceTensorStore::GetInstance().Insert(value_node.get(), device_tensor); | DeviceTensorStore::GetInstance().Insert(value_node.get(), device_tensor); | ||||
| device_tensor->set_ref_count(SIZE_MAX); | |||||
| device_tensor->ResetRefCountUsed(); | |||||
| device_tensor->set_original_ref_count(SIZE_MAX); | |||||
| device_tensor->ResetRefCount(); | |||||
| } | } | ||||
| for (auto &input_node : graph->input_nodes()) { | for (auto &input_node : graph->input_nodes()) { | ||||
| @@ -935,8 +935,8 @@ void GraphScheduler::PersistDeviceTensor(const KernelGraphPtr &graph) { | |||||
| auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0); | auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0); | ||||
| MS_EXCEPTION_IF_NULL(device_tensor); | MS_EXCEPTION_IF_NULL(device_tensor); | ||||
| DeviceTensorStore::GetInstance().Insert(input_node.get(), device_tensor); | DeviceTensorStore::GetInstance().Insert(input_node.get(), device_tensor); | ||||
| device_tensor->set_ref_count(SIZE_MAX); | |||||
| device_tensor->ResetRefCountUsed(); | |||||
| device_tensor->set_original_ref_count(SIZE_MAX); | |||||
| device_tensor->ResetRefCount(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -1008,7 +1008,7 @@ void GraphScheduler::DumpDSActor(const DataSourceActor *actor, std::ofstream &of | |||||
| const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(data_kernel, i, false); | const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(data_kernel, i, false); | ||||
| MS_EXCEPTION_IF_NULL(device_tensor); | MS_EXCEPTION_IF_NULL(device_tensor); | ||||
| ofs << "\t\t\toutput_index:" << i << "\tptr:" << device_tensor->GetPtr() << "\tsize:" << device_tensor->GetSize() | ofs << "\t\t\toutput_index:" << i << "\tptr:" << device_tensor->GetPtr() << "\tsize:" << device_tensor->GetSize() | ||||
| << "\tref_count:" << device_tensor->ref_count_dynamic_used() << "\n "; | |||||
| << "\toriginal_ref_count:" << device_tensor->original_ref_count() << "\n "; | |||||
| } | } | ||||
| } else if (actor_name.find("_HostQueueDataSourceActor") != string::npos) { | } else if (actor_name.find("_HostQueueDataSourceActor") != string::npos) { | ||||
| // Dump the member info of host queue data source actor. | // Dump the member info of host queue data source actor. | ||||
| @@ -1021,7 +1021,7 @@ void GraphScheduler::DumpDSActor(const DataSourceActor *actor, std::ofstream &of | |||||
| MS_EXCEPTION_IF_NULL(device_tensor); | MS_EXCEPTION_IF_NULL(device_tensor); | ||||
| ofs << "\t\t\tnode_order_number:" << i << "\tnode_name:" << data_node->fullname_with_scope() | ofs << "\t\t\tnode_order_number:" << i << "\tnode_name:" << data_node->fullname_with_scope() | ||||
| << "\tptr:" << device_tensor->GetPtr() << "\tsize:" << device_tensor->GetSize() | << "\tptr:" << device_tensor->GetPtr() << "\tsize:" << device_tensor->GetSize() | ||||
| << "\tref_count:" << device_tensor->ref_count_dynamic_used() << "\n "; | |||||
| << "\toriginal_ref_count:" << device_tensor->original_ref_count() << "\n "; | |||||
| } | } | ||||
| } | } | ||||
| @@ -1065,7 +1065,7 @@ void GraphScheduler::DumpKernelActor(const KernelActor *actor, std::ofstream &of | |||||
| const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(kernel, i, false); | const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(kernel, i, false); | ||||
| MS_EXCEPTION_IF_NULL(device_tensor); | MS_EXCEPTION_IF_NULL(device_tensor); | ||||
| ofs << "\t\t\toutput_index:" << i << "\tptr:" << device_tensor->GetPtr() << "\tsize:" << device_tensor->GetSize() | ofs << "\t\t\toutput_index:" << i << "\tptr:" << device_tensor->GetPtr() << "\tsize:" << device_tensor->GetSize() | ||||
| << "\tref_count:" << device_tensor->ref_count_dynamic_used() << "\n "; | |||||
| << "\toriginal_ref_count:" << device_tensor->original_ref_count() << "\n "; | |||||
| } | } | ||||
| ofs << "\t\tdevice_tensor_stores:" << actor->device_tensor_store_keys_.size() << "\n "; | ofs << "\t\tdevice_tensor_stores:" << actor->device_tensor_store_keys_.size() << "\n "; | ||||
| @@ -55,10 +55,11 @@ DeviceAddressPtr CPUDeviceContext::CreateDeviceAddress(void *device_ptr, size_t | |||||
| return std::make_shared<CPUDeviceAddress>(device_ptr, device_size, format, type_id); | return std::make_shared<CPUDeviceAddress>(device_ptr, device_size, format, type_id); | ||||
| } | } | ||||
| void CPUDeviceContext::OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const { | |||||
| void CPUDeviceContext::OptimizeGraph(const KernelGraphPtr &graph) const { | |||||
| // Update Graph Dynamic Shape Attr. | // Update Graph Dynamic Shape Attr. | ||||
| UpdateGraphDynamicShapeAttr(NOT_NULL(graph)); | UpdateGraphDynamicShapeAttr(NOT_NULL(graph)); | ||||
| SetOperatorInfo(graph->execution_order()); | |||||
| OptimizeGraphImpl(graph); | OptimizeGraphImpl(graph); | ||||
| // Remove reorder after PS feature finish adapting push/pull in auto_monad. | // Remove reorder after PS feature finish adapting push/pull in auto_monad. | ||||
| @@ -67,7 +68,11 @@ void CPUDeviceContext::OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &grap | |||||
| graph->set_execution_order(execution_order); | graph->set_execution_order(execution_order); | ||||
| } | } | ||||
| void CPUDeviceContext::OptimizeSingleOpGraph(const KernelGraphPtr &graph) const { OptimizeGraphImpl(graph); } | |||||
| void CPUDeviceContext::OptimizeSingleOpGraph(const KernelGraphPtr &graph) const { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| SetOperatorInfo(graph->execution_order()); | |||||
| OptimizeGraphImpl(graph); | |||||
| } | |||||
| void CPUDeviceContext::OptimizeGraphImpl(const KernelGraphPtr &graph) const { | void CPUDeviceContext::OptimizeGraphImpl(const KernelGraphPtr &graph) const { | ||||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | auto optimizer = std::make_shared<opt::GraphOptimizer>(); | ||||
| @@ -41,7 +41,7 @@ class CPUDeviceContext : public DeviceContext { | |||||
| TypeId type_id) const override; | TypeId type_id) const override; | ||||
| DeviceAddressType GetDeviceAddressType() const override { return DeviceAddressType::kCPU; } | DeviceAddressType GetDeviceAddressType() const override { return DeviceAddressType::kCPU; } | ||||
| void OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const override; | |||||
| void OptimizeGraph(const KernelGraphPtr &graph) const override; | |||||
| void OptimizeSingleOpGraph(const KernelGraphPtr &graph) const override; | void OptimizeSingleOpGraph(const KernelGraphPtr &graph) const override; | ||||
| void SetOperatorInfo(const std::vector<CNodePtr> &nodes) const override; | void SetOperatorInfo(const std::vector<CNodePtr> &nodes) const override; | ||||
| @@ -70,11 +70,8 @@ class DeviceContext { | |||||
| // Get device address type according different device type, such GPU, Ascend. | // Get device address type according different device type, such GPU, Ascend. | ||||
| virtual DeviceAddressType GetDeviceAddressType() const = 0; | virtual DeviceAddressType GetDeviceAddressType() const = 0; | ||||
| // The two functions below will be merged to one in the future. | |||||
| // General graph optimezer ignore device data type and format. | |||||
| virtual void OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const {} | |||||
| // Optimize the kernel graph according to device data type and format. | |||||
| virtual void OptimizeGraphWithDeviceInfo(const KernelGraphPtr &graph) const {} | |||||
| // Optimize the kernel graph for graph mode. | |||||
| virtual void OptimizeGraph(const KernelGraphPtr &graph) const {} | |||||
| // Optimize the single operator graph for PyNative mode. | // Optimize the single operator graph for PyNative mode. | ||||
| virtual void OptimizeSingleOpGraph(const KernelGraphPtr &graph) const {} | virtual void OptimizeSingleOpGraph(const KernelGraphPtr &graph) const {} | ||||
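
With the two device-specific hooks folded into a single OptimizeGraph entry point, each backend decides internally how to stage its passes and when to run operator selection. A hypothetical backend (toy types, not the MindSpore classes) would wire it up roughly like this; the real GPU implementation follows in the next hunk.

    #include <iostream>
    #include <memory>

    // Toy stand-ins; names are illustrative, not the MindSpore types.
    struct ToyGraph {};
    using ToyGraphPtr = std::shared_ptr<ToyGraph>;

    class ToyDeviceContext {
     public:
      virtual ~ToyDeviceContext() = default;
      // Single optimization entry point for graph mode, mirroring DeviceContext::OptimizeGraph.
      virtual void OptimizeGraph(const ToyGraphPtr &graph) const {}
      // Single entry point for PyNative mode, mirroring OptimizeSingleOpGraph.
      virtual void OptimizeSingleOpGraph(const ToyGraphPtr &graph) const {}
    };

    class ToyGpuContext : public ToyDeviceContext {
     public:
      void OptimizeGraph(const ToyGraphPtr &graph) const override {
        OptimizeWithoutDeviceInfo(graph);  // hardware-independent passes
        SetOperatorInfo(graph);            // kernel selection
        OptimizeWithDeviceInfo(graph);     // format/type dependent passes
      }

     private:
      void OptimizeWithoutDeviceInfo(const ToyGraphPtr &) const { std::cout << "generic passes\n"; }
      void SetOperatorInfo(const ToyGraphPtr &) const { std::cout << "select kernels\n"; }
      void OptimizeWithDeviceInfo(const ToyGraphPtr &) const { std::cout << "device passes\n"; }
    };

    int main() {
      // The compiler side now only needs one call, as in GraphCompiler::CompileGraphImpl.
      ToyGpuContext context;
      context.OptimizeGraph(std::make_shared<ToyGraph>());
      return 0;
    }
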
| @@ -165,6 +165,17 @@ DeviceAddressPtr GPUDeviceContext::CreateDeviceAddress(void *device_ptr, size_t | |||||
| return std::make_shared<GPUDeviceAddress>(device_ptr, device_size, format, type_id); | return std::make_shared<GPUDeviceAddress>(device_ptr, device_size, format, type_id); | ||||
| } | } | ||||
| void GPUDeviceContext::OptimizeGraph(const KernelGraphPtr &graph) const { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| // Optimization pass which is irrelevant to device type or format. | |||||
| OptimizeGraphWithoutDeviceInfo(graph); | |||||
| SetOperatorInfo(graph->execution_order()); | |||||
| // Optimization pass which is relevant to device type or format. | |||||
| OptimizeGraphWithDeviceInfo(graph); | |||||
| } | |||||
| void GPUDeviceContext::OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const { | void GPUDeviceContext::OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const { | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| // Operator fusion optimization. | // Operator fusion optimization. | ||||
| @@ -240,6 +251,9 @@ void GPUDeviceContext::UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> | |||||
| } | } | ||||
| void GPUDeviceContext::OptimizeSingleOpGraph(const KernelGraphPtr &graph) const { | void GPUDeviceContext::OptimizeSingleOpGraph(const KernelGraphPtr &graph) const { | ||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| SetOperatorInfo(graph->execution_order()); | |||||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | auto optimizer = std::make_shared<opt::GraphOptimizer>(); | ||||
| auto pm = std::make_shared<opt::PassManager>(); | auto pm = std::make_shared<opt::PassManager>(); | ||||
| pm->AddPass(std::make_shared<opt::ReducePrecisionFusion>("reduce_precision")); | pm->AddPass(std::make_shared<opt::ReducePrecisionFusion>("reduce_precision")); | ||||
| @@ -48,11 +48,8 @@ class GPUDeviceContext : public DeviceContext { | |||||
| TypeId type_id) const override; | TypeId type_id) const override; | ||||
| DeviceAddressType GetDeviceAddressType() const override { return DeviceAddressType::kGPU; } | DeviceAddressType GetDeviceAddressType() const override { return DeviceAddressType::kGPU; } | ||||
| // General graph optimezer ignore device data type and format. | |||||
| void OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const override; | |||||
| // Optimize the kernel graph according to device type, such format transform. | |||||
| void OptimizeGraphWithDeviceInfo(const KernelGraphPtr &graph) const override; | |||||
| // Optimize the kernel graph for graph mode. | |||||
| void OptimizeGraph(const KernelGraphPtr &graph) const override; | |||||
| // Optimize the single operator graph for PyNative mode. | // Optimize the single operator graph for PyNative mode. | ||||
| void OptimizeSingleOpGraph(const KernelGraphPtr &graph) const override; | void OptimizeSingleOpGraph(const KernelGraphPtr &graph) const override; | ||||
| @@ -67,6 +64,11 @@ class GPUDeviceContext : public DeviceContext { | |||||
| DISABLE_COPY_AND_ASSIGN(GPUDeviceContext); | DISABLE_COPY_AND_ASSIGN(GPUDeviceContext); | ||||
| bool InitDevice(); | bool InitDevice(); | ||||
| // General graph optimezer ignore device data type and format. | |||||
| void OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const; | |||||
| // Optimize the kernel graph according to device type, such format transform. | |||||
| void OptimizeGraphWithDeviceInfo(const KernelGraphPtr &graph) const; | |||||
| // Operator fusion optimization. | // Operator fusion optimization. | ||||
| void FuseOperators(const KernelGraphPtr &graph) const; | void FuseOperators(const KernelGraphPtr &graph) const; | ||||