| @@ -128,6 +128,13 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||
| // Get graph by graph id, if not exist return null ptr | |||
| KernelGraphPtr GetGraph(GraphId graph_id) const; | |||
| void ClearGraph(); | |||
| // create a single run op graph | |||
| std::shared_ptr<KernelGraph> ConstructSingleOpGraph(const OpRunInfo &op_run_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, | |||
| const std::vector<int64_t> &tensors_mask, bool is_ascend = false); | |||
| void EraseValueNodeTensor(const std::vector<int64_t> &tensors_mask, std::vector<tensor::TensorPtr> *input_tensors); | |||
| void RunOpRemoveNopNode(const KernelGraphPtr &kernel_graph) const; | |||
| void RunOpHideNopNode(const KernelGraphPtr &kernel_graph) const; | |||
| #ifdef ENABLE_DEBUGGER | |||
| // set debugger | |||
| void SetDebugger() { | |||
| @@ -163,12 +170,12 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||
| virtual void CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors, | |||
| VectorRef *outputs, | |||
| std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node); | |||
| virtual void UnifyMindIR(const KernelGraphPtr &graph) = 0; | |||
| virtual GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) = 0; | |||
| virtual void UnifyMindIR(const KernelGraphPtr &graph) {} | |||
| virtual GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { return 0; } | |||
| virtual GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) { return kInvalidGraphId; } | |||
| virtual void BuildGraphImpl(GraphId) {} | |||
| virtual void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, | |||
| VectorRef *outputs) = 0; | |||
| virtual void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) { | |||
| } | |||
| virtual void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, | |||
| const std::vector<int64_t> &tensors_mask) {} | |||
| @@ -183,7 +190,6 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||
| virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, | |||
| const std::vector<tensor::TensorPtr> &inputs_const) const; | |||
| void EraseValueNodeTensor(const std::vector<int64_t> &tensors_mask, std::vector<tensor::TensorPtr> *input_tensors); | |||
| void UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs, | |||
| const std::vector<tensor::TensorPtr> &input_tensors) const; | |||
| void UpdateOutputAbstract(const std::shared_ptr<KernelGraph> &kernel_graph, OpRunInfo *op_run_info) const; | |||
| @@ -191,10 +197,6 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||
| // create graph output for RunOp | |||
| void CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr<KernelGraph> &graph); | |||
| CNodePtr ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph); | |||
| // create a single run op graph | |||
| std::shared_ptr<KernelGraph> ConstructSingleOpGraph(const OpRunInfo &op_run_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, | |||
| const std::vector<int64_t> &tensors_mask, bool is_ascend = false); | |||
| // Generate graph info for a single op graph | |||
| GraphInfo GetSingleOpGraphInfo(const CNodePtr &kernel, const std::vector<tensor::TensorPtr> &input_tensors); | |||
| void GetSingleOpRunInfo(const CNodePtr cnode, OpRunInfo *run_info); | |||
| @@ -219,8 +221,6 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||
| AnfNodePtr FindPullNode(const AnfNodePtr &push_node, const std::vector<AnfNodePtr> &node_list); | |||
| void UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> &root_graph); | |||
| void UpdateAllGraphDynamicShapeAttr(const std::vector<KernelGraphPtr> &all_graphs); | |||
| void RunOpRemoveNopNode(const KernelGraphPtr &kernel_graph) const; | |||
| void RunOpHideNopNode(const KernelGraphPtr &kernel_graph) const; | |||
| virtual std::shared_ptr<device::Bucket> CreateBucket(uint32_t bucket_id, uint32_t bucket_size) { return nullptr; } | |||
| void InitAllBucket(const KernelGraphPtr &graph); | |||
| void AddGradAddrToBucket(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &grad_tensor); | |||
| @@ -0,0 +1,110 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"){} | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "runtime/framework/graph_compiler.h" | |||
| #include "runtime/framework/graph_scheduler.h" | |||
| namespace mindspore { | |||
| namespace runtime { | |||
| void GraphCompiler::set_device_context(device::DeviceContext *device_context) { | |||
| MS_EXCEPTION_IF_NULL(device_context); | |||
| device_context_ = device_context; | |||
| // The member variable 'session_' will be removed after removing session module. | |||
| if (session_ == nullptr) { | |||
| session_ = std::make_shared<session::SessionBasic>(); | |||
| } | |||
| } | |||
| GraphId GraphCompiler::CompileGraph(const AnfNodePtrList &nodes, const AnfNodePtrList &outputs) { | |||
| MS_EXCEPTION_IF_NULL(session_); | |||
| // Generate kernel graph. | |||
| auto graph = session_->ConstructKernelGraph(nodes, outputs); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| return CompileGraphImpl(graph); | |||
| } | |||
| GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph) { | |||
| MS_EXCEPTION_IF_NULL(device_context_); | |||
| // Optimization pass which is irrelevant to device type or format. | |||
| device_context_->OptimizeGraphWithoutDeviceInfo(graph); | |||
| device_context_->SetOperatorInfo(graph->execution_order()); | |||
| // Optimization pass which is relevant to device type or format. | |||
| device_context_->OptimizeGraphWithDeviceInfo(graph); | |||
| // Generate 'KernelMod' for all kernels and set 'KernelMod' into kernel, | |||
| // 'KernelMod' is real executive object of kernel. | |||
| device_context_->CreateKernel(graph->execution_order()); | |||
| // Transform graph to actor DAG, contains build and link. | |||
| GraphScheduler::GetInstance().Transform(graph, device_context_); | |||
| return graph->graph_id(); | |||
| } | |||
| void GraphCompiler::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, | |||
| VectorRef *outputs) { | |||
| MS_EXCEPTION_IF_NULL(session_); | |||
| auto graph = session_->GetGraph(graph_id); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| auto actor_set = GraphScheduler::GetInstance().Fetch(graph); | |||
| MS_EXCEPTION_IF_NULL(actor_set); | |||
| GraphScheduler::GetInstance().Run(actor_set); | |||
| } | |||
| void GraphCompiler::CompileAndRunGraph(session::OpRunInfo *op_run_info, const GraphInfo &graph_info, | |||
| std::vector<tensor::TensorPtr> *input_tensors, | |||
| const std::vector<int64_t> &tensors_mask, VectorRef *outputs) { | |||
| // Check if the graph cache exists. | |||
| if (run_op_graphs_.find(graph_info) == run_op_graphs_.end()) { | |||
| // Prepare the graph | |||
| MS_EXCEPTION_IF_NULL(session_); | |||
| auto graph = session_->ConstructSingleOpGraph(*op_run_info, *input_tensors, tensors_mask); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(device_context_); | |||
| device_context_->SetOperatorInfo(graph->execution_order()); | |||
| device_context_->OptimizeSingleOpGraph(graph); | |||
| MS_EXCEPTION_IF_NULL(session_); | |||
| session_->RunOpHideNopNode(graph); | |||
| device_context_->CreateKernel(graph->execution_order()); | |||
| run_op_graphs_[graph_info] = graph; | |||
| } | |||
| session_->EraseValueNodeTensor(tensors_mask, input_tensors); | |||
| // wait for allreduce | |||
| for (auto &tensor : *input_tensors) { | |||
| if (tensor->NeedWaitDevice()) { | |||
| tensor->WaitDevice(); | |||
| } | |||
| } | |||
| // run op | |||
| auto graph = run_op_graphs_[graph_info]; | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| session_->RunOpRemoveNopNode(graph); | |||
| GraphScheduler::GetInstance().Transform(graph, device_context_, input_tensors, GraphExecutionStrategy::kStep); | |||
| auto actor_set = GraphScheduler::GetInstance().Fetch(graph); | |||
| MS_EXCEPTION_IF_NULL(actor_set); | |||
| GraphScheduler::GetInstance().Run(actor_set, GraphExecutionStrategy::kStep); | |||
| } | |||
| } // namespace runtime | |||
| } // namespace mindspore | |||
| @@ -45,7 +45,7 @@ class GraphCompiler { | |||
| void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs); | |||
| // Construct single op kernel graph, compile and run the kernel graph in PyNative mode. | |||
| void CompileAndRunGraph(OpRunInfo *op_run_info, const GraphInfo &graph_info, | |||
| void CompileAndRunGraph(session::OpRunInfo *op_run_info, const GraphInfo &graph_info, | |||
| std::vector<tensor::TensorPtr> *input_tensors, const std::vector<int64_t> &tensors_mask, | |||
| VectorRef *outputs); | |||
| @@ -61,7 +61,7 @@ class GraphCompiler { | |||
| device::DeviceContext *device_context_{nullptr}; | |||
| // Single op kernel graph cache for PyNative mode. | |||
| std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_; | |||
| std::unordered_map<GraphInfo, KernelGraphPtr> run_op_graphs_; | |||
| // The member variable 'session_' will be removed after removing session module. | |||
| session::SessionPtr session_{nullptr}; | |||
| @@ -21,6 +21,11 @@ | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" | |||
| #include "runtime/device/cpu/kernel_select_cpu.h" | |||
| #include "utils/trace_base.h" | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| #include "backend/optimizer/common/pass_manager.h" | |||
| #include "backend/optimizer/cpu/insert_cast_cpu.h" | |||
| #include "backend/optimizer/pass/replace_node_by_proxy.h" | |||
| #include "backend/optimizer/pass/erase_visit_attr.h" | |||
| namespace mindspore { | |||
| namespace device { | |||
| @@ -45,6 +50,40 @@ void CPUDeviceContext::FreeMemory(DeviceAddress *const &address) const { | |||
| address->ptr_ = nullptr; | |||
| } | |||
| void CPUDeviceContext::OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const { | |||
| // Update Graph Dynamic Shape Attr. | |||
| UpdateGraphDynamicShapeAttr(NOT_NULL(graph)); | |||
| OptimizeGraphImpl(graph); | |||
| // Remove reorder after PS feature finish adapting push/pull in auto_monad. | |||
| auto execution_order = graph->execution_order(); | |||
| AnfAlgo::ReorderPosteriorExecList(NOT_NULL(&execution_order)); | |||
| graph->set_execution_order(execution_order); | |||
| } | |||
| void CPUDeviceContext::OptimizeSingleOpGraph(const KernelGraphPtr &graph) const { OptimizeGraphImpl(graph); } | |||
| void CPUDeviceContext::OptimizeGraphImpl(const KernelGraphPtr &graph) const { | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::InsertCastCPU>()); | |||
| pm->AddPass(std::make_shared<opt::EraseVisitAttr>()); | |||
| optimizer->AddPassManager(pm); | |||
| (void)optimizer->Optimize(graph); | |||
| graph->SetExecOrderByDefault(); | |||
| } | |||
| void CPUDeviceContext::UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> &graph) const { | |||
| for (const auto &cnode : graph->execution_order()) { | |||
| if (AnfAlgo::IsNodeDynamicShape(cnode)) { | |||
| AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), cnode); | |||
| MS_LOG(INFO) << "Set Dynamic Shape Attr to Node:" << cnode->fullname_with_scope(); | |||
| } | |||
| } | |||
| graph->UpdateGraphDynamicAttr(); | |||
| } | |||
| void CPUDeviceContext::SetOperatorInfo(const std::vector<CNodePtr> &nodes) const { | |||
| for (const auto &node : nodes) { | |||
| SetKernelInfo(node); | |||
| @@ -36,15 +36,23 @@ class CPUDeviceContext : public DeviceContext { | |||
| bool AllocateMemory(DeviceAddress *const &address, size_t size) const override; | |||
| void FreeMemory(DeviceAddress *const &address) const override; | |||
| void OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const override; | |||
| void OptimizeSingleOpGraph(const KernelGraphPtr &graph) const override; | |||
| void SetOperatorInfo(const std::vector<CNodePtr> &nodes) const override; | |||
| void CreateKernel(const std::vector<CNodePtr> &nodes) const override; | |||
| bool LaunchKernel(KernelMod *kernel_mod, const std::vector<AddressPtr> &inputs, | |||
| const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs) const override; | |||
| private: | |||
| DISABLE_COPY_AND_ASSIGN(CPUDeviceContext); | |||
| // Update Graph Dynamic Shape Attr. | |||
| void UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> &graph) const; | |||
| void OptimizeGraphImpl(const KernelGraphPtr &graph) const; | |||
| uint32_t device_id_; | |||
| std::shared_ptr<MemoryManager> mem_manager_; | |||
| bool initialized_; | |||
| }; | |||
| @@ -63,17 +63,23 @@ class DeviceContext { | |||
| return true; | |||
| } | |||
| // Optimize the kernel graph according to different devices. | |||
| virtual void OptimizeGraph(const KernelGraphPtr &graph) const {} | |||
| // The two functions below will be merged to one in the future. | |||
| // General graph optimezer ignore device data type and format. | |||
| virtual void OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const {} | |||
| // Optimize the kernel graph according to device data type and format. | |||
| virtual void OptimizeGraphWithDeviceInfo(const KernelGraphPtr &graph) const {} | |||
| // Optimize the single operator graph for PyNative mode. | |||
| virtual void OptimizeSingleOpGraph(const KernelGraphPtr &graph) const {} | |||
| // Select the matching backend kernels according to the data type and format of input and output for all | |||
| // execution operators, and set final device data type and format information for backend kernels, device | |||
| // data type and format which replace original data type and format will use for executing kernels. | |||
| virtual void SetOperatorInfo(const std::vector<CNodePtr> &nodes) const {} | |||
| virtual void SetOperatorInfo(const std::vector<CNodePtr> &nodes) const = 0; | |||
| // Generate 'KernelMod' for all kernels and set 'KernelMod' into kernel, | |||
| // 'KernelMod' is real executive object of kernel. | |||
| virtual void CreateKernel(const std::vector<CNodePtr> &nodes) const {} | |||
| virtual void CreateKernel(const std::vector<CNodePtr> &nodes) const = 0; | |||
| // Launch a kernel via 'KernelMod' of the kernel. | |||
| virtual bool LaunchKernel(KernelMod *kernel_mod, const std::vector<AddressPtr> &inputs, | |||
| @@ -34,7 +34,7 @@ void DeviceContextManager::ClearDeviceContexts() { | |||
| device_contexts_.clear(); | |||
| } | |||
| DeviceContext *DeviceContextManager::GetDeviceContext(const DeviceContextKey &device_context_key) { | |||
| DeviceContext *DeviceContextManager::CreateOrGetDeviceContext(const DeviceContextKey &device_context_key) { | |||
| std::string device_context_key_str = device_context_key.ToString(); | |||
| std::lock_guard<std::mutex> guard(lock_); | |||
| @@ -36,7 +36,7 @@ class DeviceContextManager { | |||
| return instance; | |||
| } | |||
| void Register(const std::string &device_name, DeviceContextCreator &&device_context_creator); | |||
| DeviceContext *GetDeviceContext(const DeviceContextKey &device_info); | |||
| DeviceContext *CreateOrGetDeviceContext(const DeviceContextKey &device_context_key); | |||
| void ClearDeviceContexts(); | |||
| private: | |||
| @@ -27,16 +27,31 @@ | |||
| #include "runtime/device/gpu/gpu_buffer_mgr.h" | |||
| #include "backend/kernel_compiler/common_utils.h" | |||
| #include "runtime/device/gpu/gpu_common.h" | |||
| #include "runtime/hardware/gpu/optimizer.h" | |||
| #include "common/trans.h" | |||
| #include "utils/context/graph_kernel_flags.h" | |||
| namespace mindspore { | |||
| namespace device { | |||
| namespace gpu { | |||
| bool GPUDeviceContext::Initialize() { | |||
| if (initialized_ == true) { | |||
| CHECK_OP_RET_WITH_EXCEPT(CudaDriver::SetDevice(UintToInt(device_context_key_.device_id_)), | |||
| "Failed to set device id"); | |||
| GPUMemoryAllocator::GetInstance().CheckMaxDeviceMemory(); | |||
| return true; | |||
| } | |||
| // Set device id | |||
| const void *collective_handle_ = CollectiveInitializer::instance().collective_handle(); | |||
| bool collective_inited = CollectiveInitializer::instance().collective_inited(); | |||
| if (collective_inited && collective_handle_ != nullptr) { | |||
| auto get_local_rank_funcptr = | |||
| reinterpret_cast<GetLocalRankId>(dlsym(const_cast<void *>(collective_handle_), "local_rank_id")); | |||
| MS_EXCEPTION_IF_NULL(get_local_rank_funcptr); | |||
| device_context_key_.device_id_ = IntToUint((*get_local_rank_funcptr)()); | |||
| } | |||
| // Set device id and initialize device resource. | |||
| bool ret = InitDevice(); | |||
| if (!ret) { | |||
| @@ -50,8 +65,6 @@ bool GPUDeviceContext::Initialize() { | |||
| mem_manager_->MallocDeviceMemory(); | |||
| // Initialize NCCL. | |||
| const void *collective_handle_ = CollectiveInitializer::instance().collective_handle(); | |||
| bool collective_inited = CollectiveInitializer::instance().collective_inited(); | |||
| if (collective_inited && collective_handle_ != nullptr) { | |||
| auto init_nccl_comm_funcptr = | |||
| reinterpret_cast<InitNCCLComm>(dlsym(const_cast<void *>(collective_handle_), "InitNCCLComm")); | |||
| @@ -152,6 +165,97 @@ bool GPUDeviceContext::AllocateContinuousMemory(const std::vector<DeviceAddress | |||
| return true; | |||
| } | |||
| void GPUDeviceContext::OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| // Operator fusion optimization. | |||
| FuseOperators(graph); | |||
| device::gpu::AssignGpuStream(graph); | |||
| // Update Graph Dynamic Shape Attr. | |||
| UpdateGraphDynamicShapeAttr(NOT_NULL(graph)); | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| const bool pynative_mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode; | |||
| // Hide NopOp from execution graph in graph mode | |||
| if (!pynative_mode) { | |||
| opt::HideNopNode(graph.get()); | |||
| } | |||
| } | |||
| void GPUDeviceContext::OptimizeGraphWithDeviceInfo(const KernelGraphPtr &graph) const { | |||
| // Graph optimization relevant to device data format | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::BatchNormReluFusion>()); | |||
| pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>()); | |||
| pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>()); | |||
| pm->AddPass(std::make_shared<opt::PostBatchNormAddReluFusion>()); | |||
| pm->AddPass(std::make_shared<opt::BatchNormAddReluGradFusion>()); | |||
| pm->AddPass(std::make_shared<opt::InsertFormatTransformOp>()); | |||
| pm->AddPass(std::make_shared<opt::RemoveFormatTransformPair>()); | |||
| pm->AddPass(std::make_shared<opt::RemoveRedundantFormatTransform>()); | |||
| pm->AddPass(std::make_shared<opt::CudnnInplaceAggregate>()); | |||
| pm->AddPass(std::make_shared<opt::ReluV2Pass>()); | |||
| pm->AddPass(std::make_shared<opt::AddReluV2Fusion>()); | |||
| pm->AddPass(std::make_shared<opt::AddReluGradV2Fusion>()); | |||
| pm->AddPass(std::make_shared<opt::AllReduceFusion>()); | |||
| pm->AddPass(std::make_shared<opt::GetitemTuple>()); | |||
| pm->AddPass(std::make_shared<opt::ReducePrecisionFusion>("reduce_precision")); | |||
| optimizer->AddPassManager(pm); | |||
| (void)optimizer->Optimize(graph); | |||
| graph->SetExecOrderByDefault(); | |||
| } | |||
| void GPUDeviceContext::FuseOperators(const KernelGraphPtr &graph) const { | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::AdamWeightDecayFusion>()); | |||
| pm->AddPass(std::make_shared<opt::AdamFusion>()); | |||
| pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayScaleFusion>()); | |||
| pm->AddPass(std::make_shared<opt::ApplyMomentumScaleFusion>()); | |||
| pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayFusion>()); | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) { | |||
| pm->AddPass(std::make_shared<opt::CastAllFusion>("cast_all")); | |||
| } | |||
| pm->AddPass(std::make_shared<opt::CombineMomentumFusion>("combine_momentum")); | |||
| pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>()); | |||
| pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>()); | |||
| pm->AddPass(std::make_shared<opt::PrintReduceFusion>("print_reduce")); | |||
| optimizer->AddPassManager(pm); | |||
| (void)optimizer->Optimize(graph); | |||
| graph->SetExecOrderByDefault(); | |||
| // Graph kernel fusion optimization | |||
| if (!context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) { | |||
| return; | |||
| } | |||
| opt::GraphKernelOptimize(graph); | |||
| graph->SetExecOrderByDefault(); | |||
| } | |||
| void GPUDeviceContext::UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> &graph) const { | |||
| for (const auto &cnode : graph->execution_order()) { | |||
| if (AnfAlgo::IsNodeDynamicShape(cnode)) { | |||
| AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), cnode); | |||
| MS_LOG(INFO) << "Set Dynamic Shape Attr to Node:" << cnode->fullname_with_scope(); | |||
| } | |||
| } | |||
| graph->UpdateGraphDynamicAttr(); | |||
| } | |||
| void GPUDeviceContext::OptimizeSingleOpGraph(const KernelGraphPtr &graph) const { | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::ReducePrecisionFusion>("reduce_precision")); | |||
| optimizer->AddPassManager(pm); | |||
| (void)optimizer->Optimize(graph); | |||
| graph->SetExecOrderByDefault(); | |||
| } | |||
| void GPUDeviceContext::SetOperatorInfo(const std::vector<CNodePtr> &nodes) const { | |||
| for (const auto &node : nodes) { | |||
| SetKernelInfo(node); | |||
| @@ -43,6 +43,14 @@ class GPUDeviceContext : public DeviceContext { | |||
| bool AllocateContinuousMemory(const std::vector<DeviceAddress *> &addr_list, size_t total_size, | |||
| const std::vector<size_t> &size_list) const override; | |||
| // General graph optimezer ignore device data type and format. | |||
| void OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const override; | |||
| // Optimize the kernel graph according to device type, such format transform. | |||
| void OptimizeGraphWithDeviceInfo(const KernelGraphPtr &graph) const override; | |||
| // Optimize the single operator graph for PyNative mode. | |||
| void OptimizeSingleOpGraph(const KernelGraphPtr &graph) const override; | |||
| void SetOperatorInfo(const std::vector<CNodePtr> &nodes) const override; | |||
| void CreateKernel(const std::vector<CNodePtr> &nodes) const override; | |||
| bool LaunchKernel(KernelMod *kernel_mod, const std::vector<AddressPtr> &inputs, | |||
| @@ -54,6 +62,12 @@ class GPUDeviceContext : public DeviceContext { | |||
| DISABLE_COPY_AND_ASSIGN(GPUDeviceContext); | |||
| bool InitDevice(); | |||
| // Operator fusion optimization. | |||
| void FuseOperators(const KernelGraphPtr &graph) const; | |||
| // Update Graph Dynamic Shape Attr. | |||
| void UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> &graph) const; | |||
| std::shared_ptr<MemoryManager> mem_manager_; | |||
| std::vector<void *> streams_; | |||
| bool initialized_; | |||
| @@ -0,0 +1,51 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_GPU_OPTIMIZER_H_ | |||
| #define MINDSPORE_CCSRC_RUNTIME_HARDWARE_GPU_OPTIMIZER_H_ | |||
| #include "backend/optimizer/common/helper.h" | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| #include "backend/optimizer/common/pass_manager.h" | |||
| #include "backend/optimizer/common/common_backend_optimization.h" | |||
| #include "backend/optimizer/gpu/adam_weight_decay_fusion.h" | |||
| #include "backend/optimizer/gpu/adam_fusion.h" | |||
| #include "backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h" | |||
| #include "backend/optimizer/gpu/apply_momentum_scale_fusion.h" | |||
| #include "backend/optimizer/gpu/apply_momentum_weight_fusion.h" | |||
| #include "backend/optimizer/gpu/batch_norm_relu_fusion.h" | |||
| #include "backend/optimizer/gpu/batch_norm_relu_grad_fusion.h" | |||
| #include "backend/optimizer/gpu/batch_norm_add_relu_fusion.h" | |||
| #include "backend/optimizer/gpu/post_batch_norm_add_relu_fusion.h" | |||
| #include "backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.h" | |||
| #include "backend/optimizer/gpu/combine_momentum_fusion.h" | |||
| #include "backend/optimizer/gpu/combine_cast_fusion.h" | |||
| #include "backend/optimizer/gpu/cudnn_inplace_fusion.h" | |||
| #include "backend/optimizer/gpu/insert_format_transform_op.h" | |||
| #include "backend/optimizer/gpu/replace_momentum_cast_fusion.h" | |||
| #include "backend/optimizer/gpu/replace_addn_fusion.h" | |||
| #include "backend/optimizer/gpu/print_reduce_fusion.h" | |||
| #include "backend/optimizer/gpu/remove_format_transform_pair.h" | |||
| #include "backend/optimizer/gpu/remove_redundant_format_transform.h" | |||
| #include "backend/optimizer/gpu/reduce_precision_fusion.h" | |||
| #include "backend/optimizer/gpu/relu_v2_pass.h" | |||
| #include "backend/optimizer/gpu/add_relu_v2_fusion.h" | |||
| #include "backend/optimizer/gpu/add_relu_grad_v2_fusion.h" | |||
| #include "backend/optimizer/graph_kernel/graph_kernel_optimization.h" | |||
| #include "backend/optimizer/pass/communication_op_fusion.h" | |||
| #include "backend/optimizer/pass/getitem_tuple.h" | |||
| #endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_GPU_OPTIMIZER_H_ | |||