diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h
index fe363145d0..47673a3714 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h
@@ -77,6 +77,15 @@ class GpuKernel : public KernelMod {
     }
     return GetValue<T>(attr);
   }
+  template <typename T>
+  inline T GetAttrWithDefault(const CNodePtr &kernel_node, const std::string &key, const T &value) const {
+    const PrimitivePtr &prim = AnfAlgo::GetCNodePrimitive(kernel_node);
+    const ValuePtr &attr = prim->GetAttr(key);
+    if (attr == nullptr) {
+      return value;
+    }
+    return GetValue<T>(attr);
+  }
   // expand Nd Shape to 4d (N in [0,4])
   void ShapeNdTo4d(const std::vector<size_t> &src, std::vector<size_t> *dst) {
     if (src.size() > 4) {
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h
index 2bd32ce0e1..d9490d8464 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h
@@ -54,7 +54,8 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
         output_size_(0),
         padded_size_(0),
         workspace_size_(0),
-        use_pad_(true) {}
+        use_pad_(true),
+        beta_(0) {}
   ~ConvGradInputGpuBkwKernel() override { DestroyResource(); }

   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -75,13 +76,12 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
     }

     const float alpha = 1;
-    const float beta = 0;
     if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) {
       T *padded = GetDeviceAddress<T>(workspace, 1);
       CHECK_CUDNN_RET_WITH_EXCEPT(
         cudnnConvolutionBackwardData(cudnn_handle_, &alpha, w_desc_, w, dy_desc_, dy, conv_desc_, algo_, work_space,
-                                     workspace_size_, &beta, padded_descriptor_, padded),
+                                     workspace_size_, &beta_, padded_descriptor_, padded),
         "ConvolutionBackwardData failed");
       if (data_format_ == "NHWC") {
         CalPadGradNHWC(output_size_ / sizeof(T), padded, n_, old_height_, old_width_, c_, old_height_ + pad_height_,
@@ -93,7 +93,7 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
     } else {
       CHECK_CUDNN_RET_WITH_EXCEPT(
         cudnnConvolutionBackwardData(cudnn_handle_, &alpha, w_desc_, w, dy_desc_, dy, conv_desc_, algo_, work_space,
-                                     workspace_size_, &beta, dx_desc_, dx),
+                                     workspace_size_, &beta_, dx_desc_, dx),
         "ConvolutionBackwardData failed");
     }
     return true;
@@ -188,6 +188,8 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
         "cudnnSetConvolutionMathType failed.")
     }
     SelectAlgorithm(dx_desc_real);
+    // "cover" overwrites the output (beta = 0); "accumulation" adds into it (beta = 1).
+    beta_ = GetAttrWithDefault(kernel_node, "inplace_algo", std::string("cover")) == "cover" ? 0 : 1;
     InitSizeLists();
     return true;
   }
@@ -349,6 +350,7 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
   size_t padded_size_;
   size_t workspace_size_;
   bool use_pad_;
+  float beta_;
 };
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_grad_ex_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_grad_ex_gpu_kernel.h
index cbf15d0688..20c18ab81d 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_grad_ex_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_grad_ex_gpu_kernel.h
@@ -46,7 +46,8 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel {
       scale_bias_diff_desc_(nullptr),
       activation_desc_(nullptr),
       handle_(nullptr),
-      cudnn_data_type_(CUDNN_DATA_FLOAT) {}
+      cudnn_data_type_(CUDNN_DATA_FLOAT),
+      beta_data_diff_(0) {}
   ~FusedBatchNormGradExGpuKernel() override { DestroyResource(); }

   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -88,12 +89,10 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel {
     }

     const float alpha_data_diff = 1;
-    const float beta_data_diff = 0;
     const float alpha_param_diff = 1;
     const float beta_param_diff = 0;
-
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnBatchNormalizationBackwardEx(
-                                  handle_, mode_, bn_ops_, &alpha_data_diff, &beta_data_diff, &alpha_param_diff,
+                                  handle_, mode_, bn_ops_, &alpha_data_diff, &beta_data_diff_, &alpha_param_diff,
                                   &beta_param_diff, x_desc_, x, y_desc_, y, dy_desc_, dy, dz_desc_, dz, dx_desc_, dx,
                                   scale_bias_diff_desc_, scale, bias, dscale, dbias, epsilon_, save_mean,
                                   save_variance, activation_desc_, workspace_addr, workspace_size_, reserve_addr,
                                   reserve_size_),
@@ -141,6 +140,8 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel {
       return true;
     }
     std::string format = AnfAlgo::GetInputFormat(kernel_node, 0);
+    // "cover" overwrites the output (beta = 0); "accumulation" adds into it (beta = 1).
+    beta_data_diff_ = GetAttrWithDefault(kernel_node, "inplace_algo", std::string("cover")) == "cover" ? 0 : 1;
     SetTensorDescriptor(format, shape);
     InitSizeLists();
     return true;
@@ -285,6 +285,7 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel {

   cudnnHandle_t handle_;
   cudnnDataType_t cudnn_data_type_;
+  float beta_data_diff_;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/cudnn_inplace_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/cudnn_inplace_fusion.cc
new file mode 100644
index 0000000000..9d7eec3233
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/gpu/cudnn_inplace_fusion.cc
@@ -0,0 +1,195 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/optimizer/gpu/cudnn_inplace_fusion.h"
+
+#include <algorithm>
+#include <map>
+#include <memory>
+#include <set>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "backend/session/anf_runtime_algorithm.h"
+#include "ir/primitive.h"
+#include "utils/utils.h"
+#include "utils/contract.h"
+#include "backend/optimizer/common/helper.h"
+#include "runtime/device/gpu/kernel_info_setter.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+struct AnfNodeIndex {
+  AnfNodeIndex() : node(nullptr), index(0) {}
+  AnfNodeIndex(const AnfNodePtr &n, const int i) : node(n), index(i) {}
+  AnfNodePtr node;
+  uint32_t index;
+};
+
+// op name -> output index
+std::map<std::string, uint32_t> kInplaceOpNames = {{kConv2DBackpropInputOpName, 0},
+                                                   {kFusedBatchNormGradExWithAddAndActivation, 3}};
+
+std::set<std::string> kSkipOpNames = {
+  kTensorAddOpName,
+};
+
+// op name -> input index
+std::map<std::string, uint32_t> kAggregatesOpNames = {
+  {kConv2DBackpropInputOpName, 0}, {kMaxPoolGradOpName, 2}, {kFusedBatchNormGradExWithAddAndActivation, 0}};
+
+template <typename T>
+void SetPrimAttr(AnfNodePtr inplace_node, const string &key, const T &value) {
+  auto primitive = AnfAlgo::GetCNodePrimitive(inplace_node);
+  MS_EXCEPTION_IF_NULL(primitive);
+  primitive->AddAttr(key, MakeValue(value));
+}
+
+void SetNodeAttr(AnfNodeIndex aggregate_node, AnfNodePtr skip_node, std::vector<AnfNodeIndex> *inplace_node) {
+  SetPrimAttr(aggregate_node.node, "aggregate", true);
+  SetPrimAttr(aggregate_node.node, "aggregate_input_index", aggregate_node.index);
+  SetPrimAttr(skip_node, "skip", true);
+
+  static uint32_t group = 0;
+  for (size_t i = 0; i < inplace_node->size(); i++) {
+    auto algo = (i == 0) ? "cover" : "accumulation";
+    SetPrimAttr((*inplace_node)[i].node, "inplace_algo", algo);
+    SetPrimAttr((*inplace_node)[i].node, "inplace_group", group);
+    SetPrimAttr((*inplace_node)[i].node, "inplace_output_index", (*inplace_node)[i].index);
+  }
+  group++;
+}
+
+void InsertControlDependToGraph(const FuncGraphPtr &graph, const std::vector<AnfNodeIndex> &inplace_nodes) {
+  std::vector<AnfNodePtr> inputs1 = {NewValueNode(std::make_shared<Primitive>(prim::kPrimControlDepend->name())),
+                                     inplace_nodes[0].node, inplace_nodes[1].node};
+  auto control_depend_node = graph->NewCNode(inputs1);
+
+  auto return_node = graph->get_return();
+  MS_EXCEPTION_IF_NULL(return_node);
+  // Mount the Depend before make_tuple; otherwise the graph output becomes `(tensor,)` rather than `tensor`.
+  auto return_input = return_node->input(kFirstDataInputIndex)->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(return_input);
+  std::vector<AnfNodePtr> inputs2 = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
+                                     return_input->input(kFirstDataInputIndex), control_depend_node};
+  auto depend_node = graph->NewCNode(inputs2);
+  return_node->set_input(kFirstDataInputIndex, depend_node);
+}
+
+bool PatternMatch(const FuncGraphPtr &graph, const AnfNodePtr &node, AnfNodeIndex *aggregate, AnfNodePtr *skip_node,
+                  std::vector<AnfNodeIndex> *inplace) {
+  MS_EXCEPTION_IF_NULL(skip_node);
+  MS_EXCEPTION_IF_NULL(aggregate);
+  if (!node->isa<CNode>()) {
+    return false;
+  }
+  auto aggregate_iter = kAggregatesOpNames.find(AnfAlgo::GetCNodeName(node));
+  if (aggregate_iter == kAggregatesOpNames.end()) {
+    return false;
+  }
+  aggregate->node = node;
+  aggregate->index = aggregate_iter->second;
+
+  *skip_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), aggregate_iter->second);
+  if (*skip_node == nullptr || !(*skip_node)->isa<CNode>() ||
+      kSkipOpNames.count(AnfAlgo::GetCNodeName(*skip_node)) == 0 ||
+      GetRealNodeUsedList(graph, *skip_node)->size() >= 2) {
+    return false;
+  }
+
+  auto cnode = (*skip_node)->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode); i++) {
+    auto inplace_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(*skip_node), i);
+    if (!inplace_node->isa<CNode>()) {
+      return false;
+    }
+    // Check that the inplace node has no user other than the TensorAdd node.
+    if (GetRealNodeUsedList(graph, inplace_node)->size() >= 2) {
+      return false;
+    }
+
+    // Skip TupleGetItem nodes.
+    if (AnfAlgo::GetCNodeName(inplace_node) == prim::kPrimTupleGetItem->name()) {
+      inplace_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(inplace_node), 0);
+    }
+
+    auto inplace_iter = kInplaceOpNames.find(AnfAlgo::GetCNodeName(inplace_node));
+    if (inplace_iter == kInplaceOpNames.end()) {
+      return false;
+    }
+
+    inplace->push_back(AnfNodeIndex(inplace_node, inplace_iter->second));
+  }
+
+  return true;
+}
+
+std::map<AnfNodePtr, size_t> TopoIndex(const std::vector<AnfNodePtr> &node_list) {
+  std::map<AnfNodePtr, size_t> topo_index;
+  for (size_t i = 0; i < node_list.size(); i++) {
+    topo_index.insert(make_pair(node_list[i], i));
+  }
+  return topo_index;
+}
+}  // namespace
+
+bool CudnnInplaceAggregate::Run(const FuncGraphPtr &graph) {
+  MS_EXCEPTION_IF_NULL(graph);
+  std::vector<AnfNodePtr> node_list = TopoSort(graph->get_return());
+  auto topo_index = TopoIndex(node_list);
+
+  for (auto node : node_list) {
+    AnfNodeIndex aggregate_node;
+    AnfNodePtr skip_node;
+    std::vector<AnfNodeIndex> inplace_node;
+    // 1. Match the aggregate/skip/inplace pattern.
+    if (!PatternMatch(graph, node, &aggregate_node, &skip_node, &inplace_node)) {
+      continue;
+    }
+
+    // 2. Keep the original topological order in case there are dependences between the inplace nodes.
+    std::sort(inplace_node.begin(), inplace_node.end(), [&topo_index](const AnfNodeIndex &n1, const AnfNodeIndex &n2) {
+      auto iter1 = topo_index.find(n1.node);
+      auto iter2 = topo_index.find(n2.node);
+      if (iter1 == topo_index.end() || iter2 == topo_index.end()) {
+        MS_LOG(EXCEPTION) << "Node does not exist in topo order. node1: " << n1.node->DebugString()
+                          << ", node2: " << n2.node->DebugString();
+      }

+      return iter1->second < iter2->second;
+    });
+    MS_LOG(INFO) << "[inplace optimizer] aggregate node: " << aggregate_node.index << ", "
+                 << aggregate_node.node->DebugString() << "; skip node: " << skip_node->DebugString() << std::endl
+                 << "; inplace node 0: " << inplace_node[0].index << ", " << inplace_node[0].node->DebugString()
+                 << std::endl
+                 << "; inplace node 1: " << inplace_node[1].index << ", " << inplace_node[1].node->DebugString()
+                 << std::endl;
+    // 3. Mark the nodes with the inplace attributes.
+    SetNodeAttr(aggregate_node, skip_node, &inplace_node);
+    // 4. Add a dependence between the inplace nodes so they execute one after another.
+    InsertControlDependToGraph(graph, inplace_node);
+  }
+
+  return true;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/cudnn_inplace_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/cudnn_inplace_fusion.h
new file mode 100644
index 0000000000..3b36a49913
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/gpu/cudnn_inplace_fusion.h
@@ -0,0 +1,35 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_CUDNN_INPLACE_AGGREGATE_FUSION_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_CUDNN_INPLACE_AGGREGATE_FUSION_H_
+
+#include <memory>
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+// Let every node in an inplace group write into the same device buffer (the
+// later ones accumulate with beta = 1), so the TensorAdd that used to combine
+// their outputs can be skipped at runtime.
+class CudnnInplaceAggregate : public Pass {
+ public:
+  CudnnInplaceAggregate() : Pass("cudnn_inplace_aggregate") {}
+  ~CudnnInplaceAggregate() override = default;
+  bool Run(const FuncGraphPtr &g) override;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_CUDNN_INPLACE_AGGREGATE_FUSION_H_
diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.cc b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.cc
index 1de2021490..5290d44213 100644
--- a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.cc
+++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.cc
@@ -18,6 +18,7 @@
 #include <algorithm>
 #include "backend/session/anf_runtime_algorithm.h"
 #include "backend/optimizer/common/helper.h"
+#include "runtime/device/kernel_runtime_manager.h"

 namespace mindspore {
 namespace device {
@@ -191,6 +192,33 @@ bool MemSwapManager::IsCommunicationRelevantOp(const AnfNodePtr &kernel) const {
   return adjacent_with_communication_op;
 }

+bool MemSwapManager::IsInplaceRelevantOp(const TensorInfo &tensor) {
+  MS_EXCEPTION_IF_NULL(tensor.kernel_);
+  if (AnfAlgo::IsInplaceNode(tensor.kernel_, "inplace_algo") || AnfAlgo::IsInplaceNode(tensor.kernel_, "skip")) {
+    return true;
+  }
+
+  MS_EXCEPTION_IF_NULL(kernel_graph_);
+  const auto &graph_manager = kernel_graph_->manager();
+  MS_EXCEPTION_IF_NULL(graph_manager);
+  NodeUsersMap &user_map = graph_manager->node_users();
+
+  auto users = user_map.find(tensor.kernel_);
+  for (const auto &user : users->second) {
+    if (!AnfAlgo::IsInplaceNode(user.first, "aggregate")) {
+      continue;
+    }
+
+    auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(user.first, user.second);
+    if (tensor.output_idx_ == kernel_with_index.second) {
+      MS_LOG(INFO) << "[inplace optimizer] tensor: " << tensor.kernel_->DebugString()
+                   << ", output idx: " << tensor.output_idx_
+                   << " used by aggregate node: " << user.first->DebugString();
+      return true;
+    }
+  }
+  return false;
+}
+
 void MemSwapManager::SaveUserKernelTopoOrder() {
   MS_EXCEPTION_IF_NULL(kernel_graph_);
   const auto &graph_manager = kernel_graph_->manager();
@@ -231,7 +259,10 @@ void MemSwapManager::AddSwapInfo() {
     }

     const AnfNodePtr &kernel = tensor.kernel_;
-    if (IsCommunicationRelevantOp(kernel)) {
+    bool filter = IsCommunicationRelevantOp(kernel) || IsInplaceRelevantOp(tensor);
+    if (filter) {
+      MS_LOG(INFO) << "[inplace optimizer] ignore swap tensor: " << kernel->DebugString()
+                   << ", index: " << tensor.output_idx_;
       continue;
     }
diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.h b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.h
index fa2da8d721..b50000779e 100644
--- a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.h
+++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.h
@@ -136,6 +136,8 @@ class MemSwapManager {

   bool IsCommunicationRelevantOp(const AnfNodePtr &kernel) const;

+  bool IsInplaceRelevantOp(const TensorInfo &tensor);
+
   std::vector<CNodePtr> execution_order_;
   std::vector<TensorInfo> ordered_tensors_;
   std::unordered_map<void *, KernelExecutionInfo> kernel_execution_info_;
diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
index 47b82d1435..e14382ce95 100644
--- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
+++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
@@ -1031,6 +1031,21 @@ void AnfRuntimeAlgorithm::SetNodeInput(const CNodePtr &node, const AnfNodePtr &i
   node->set_input(index + 1, input_node);
 }

+bool AnfRuntimeAlgorithm::IsInplaceNode(const mindspore::AnfNodePtr &kernel, const string &type) {
+  MS_EXCEPTION_IF_NULL(kernel);
+  auto primitive = AnfAlgo::GetCNodePrimitive(kernel);
+  if (!primitive) {
+    return false;
+  }
+
+  auto inplace_attr = primitive->GetAttr(type);
+  if (inplace_attr == nullptr) {
+    return false;
+  }
+
+  return true;
+}
+
 bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   if (!node->isa<CNode>()) {
diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h
index d4a5f00a25..017afe036c 100644
--- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h
+++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h
@@ -206,6 +206,7 @@ class AnfRuntimeAlgorithm {
   // get real input index for some tbe ops which input order is different between me and tbe impl
   static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index);
   static bool IsCommunicationOp(const AnfNodePtr &node);
+  static bool IsInplaceNode(const AnfNodePtr &node, const string &type);
   static bool IsGetNext(const NotNull<AnfNodePtr> &node);
   static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node);
   static std::vector<KernelGraphPtr> GetCallSwitchKernelGraph(const CNodePtr &cnode);
diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc
index 294a97e4f9..03b53892ea 100644
--- a/mindspore/ccsrc/backend/session/gpu_session.cc
+++ b/mindspore/ccsrc/backend/session/gpu_session.cc
@@ -39,6 +39,7 @@
 #include "backend/optimizer/gpu/insert_format_transform_op.h"
 #include "backend/optimizer/gpu/remove_format_transform_pair.h"
 #include "backend/optimizer/gpu/remove_redundant_format_transform.h"
+#include "backend/optimizer/gpu/cudnn_inplace_fusion.h"
 #include "backend/optimizer/graph_kernel/graph_kernel_splitter.h"
 #include "backend/optimizer/graph_kernel/graph_kernel_expander.h"
 #include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc
index 5e35021438..3e04258d75 100644
--- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc
+++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc
@@ -15,6 +15,7 @@
  */
 #include "runtime/device/gpu/gpu_kernel_runtime.h"
 #include <algorithm>
+#include <map>
 #include "pybind11/pybind11.h"
 #include "runtime/device/gpu/gpu_device_address.h"
 #include "runtime/device/gpu/cuda_driver.h"
@@ -279,6 +280,53 @@ void GPUKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std::v
   ClearOutputAddress(inputs, value_nodes, execution_order);
 }

+void GPUKernelRuntime::AllocInplaceNodeMemory(const session::KernelGraph *graph) {
+  // Collect the inplace nodes by their "inplace_group" attribute.
+  std::map<uint32_t, std::vector<CNodePtr>> inplace_groups;
+  auto kernel_cnodes = graph->execution_order();
+  for (auto &kernel : kernel_cnodes) {
+    if (!AnfAlgo::IsInplaceNode(kernel, "inplace_algo")) {
+      continue;
+    }
+    auto primitive = AnfAlgo::GetCNodePrimitive(kernel);
+    auto group_attr = primitive->GetAttr("inplace_group");
+    MS_EXCEPTION_IF_NULL(group_attr);
+    auto group_id = GetValue<uint32_t>(group_attr);
+    inplace_groups[group_id].push_back(kernel);
+  }
+
+  for (auto &group : inplace_groups) {
+    auto &item = group.second;
+    // Only compute in place when the group size >= 2.
+    if (item.size() < 2) {
+      continue;
+    }
+
+    auto primitive = AnfAlgo::GetCNodePrimitive(item[0]);
+    auto output_index = GetValue<uint32_t>(primitive->GetAttr("inplace_output_index"));
+    auto device_address = AnfAlgo::GetMutableOutputAddr(item[0], output_index, false);
+    if (device_address->GetPtr() != nullptr) {
+      continue;
+    }
+
+    auto kernel_mod = AnfAlgo::GetKernelMod(item[0]);
+    auto output_size = kernel_mod->GetOutputSizeList();
+    auto ret = mem_manager_->MallocMemFromMemPool(device_address, output_size[output_index]);
+    if (!ret) {
+      MS_LOG(EXCEPTION) << "Cannot alloc address, tensor size is: " << output_size[output_index];
+    }
+
+    // Every node in the group shares the same output address.
+    for (auto &node : item) {
+      auto prim = AnfAlgo::GetCNodePrimitive(node);
+      auto index = GetValue<uint32_t>(prim->GetAttr("inplace_output_index"));
+      AnfAlgo::SetOutputAddr(device_address, index, node.get());
+      MS_LOG(INFO) << "[inplace optimizer] group id: " << group.first << ", node: " << node->DebugString()
+                   << ", output_index: " << index;
+    }
+  }
+}
+
 void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
@@ -545,6 +591,7 @@ bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph, De
   mem_reuse_util_->ResetDynamicUsedRefCount();
   // The inputs and outputs memory of communication kernel need be continuous, so separate processing.
   AllocCommunicationOpDynamicRes(graph);
+  AllocInplaceNodeMemory(graph);

 #ifdef ENABLE_DEBUGGER
   debugger_ = debugger;
@@ -562,6 +609,10 @@ bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph, De
   for (const auto &kernel : kernels) {
     auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
     MS_EXCEPTION_IF_NULL(kernel_mod);
+    if (AnfAlgo::IsInplaceNode(kernel, "skip")) {
+      MS_LOG(INFO) << "[inplace optimizer] skip node: " << kernel->DebugString();
+      continue;
+    }
     AddressPtrList kernel_inputs;
     AddressPtrList kernel_workspaces;
     AddressPtrList kernel_outputs;
@@ -808,6 +859,19 @@ bool GPUKernelRuntime::AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &k
       // Graph may be "nop node + depend + node", the input of node is the depend, so this case need skip nop node.
       device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, true);
     }
+
+    // Get the in-place output address of the skipped node.
+    if (AnfAlgo::IsInplaceNode(kernel, "aggregate")) {
+      auto primitive = AnfAlgo::GetCNodePrimitive(kernel);
+      auto input_index = GetValue<uint32_t>(primitive->GetAttr("aggregate_input_index"));
+      if (i == input_index) {
+        auto skip_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(kernel), input_index);
+        device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(skip_node, 0, false);
+        MS_LOG(INFO) << "[inplace optimizer] aggregate: " << kernel->DebugString()
+                     << ", skip: " << skip_node->DebugString() << ", address: " << device_address->GetMutablePtr();
+      }
+    }
+
     MS_EXCEPTION_IF_NULL(device_address);
     UpdateHostSwapInQueue(device_address, mock);
     MS_EXCEPTION_IF_NULL(device_address->ptr_);
@@ -969,6 +1033,15 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel)
   }
   // Free the input of kernel by reference count.
   for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
+    // The aggregate input reuses the inplace nodes' shared output buffer, so it is not freed here.
+    if (AnfAlgo::IsInplaceNode(kernel, "aggregate")) {
+      auto primitive = AnfAlgo::GetCNodePrimitive(kernel);
+      auto index = GetValue<uint32_t>(primitive->GetAttr("aggregate_input_index"));
+      if (i == index) {
+        continue;
+      }
+    }
+
     auto kernel_ref_count_ptr = mem_reuse_util_->GetKernelInputRef(cnode, i);
     if (kernel_ref_count_ptr == nullptr) {
       continue;
diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h
index 022143fd68..8adc434110 100644
--- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h
+++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h
@@ -94,6 +94,7 @@ class GPUKernelRuntime : public KernelRuntime {
   void UpdateHostSwapInQueue(const DeviceAddressPtr device_address, bool mock);
   void UpdateHostSwapOutQueue(bool mock);
   void ClearSwapInfo(bool mock);
+  void AllocInplaceNodeMemory(const session::KernelGraph *graph);
   std::unordered_map<uint32_t, MemReuseUtilPtr> mem_reuse_util_map_;
   std::unordered_map<uint32_t, MemSwapManagerPtr> mem_swap_map_;
   std::unordered_map<uint32_t, bool> is_first_step_map_;
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index 4377d72409..e06e45d85e 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -193,6 +193,7 @@ constexpr auto kEmbeddingLookupProxyOpName = "EmbeddingLookupProxy";
 constexpr auto kPaddingOpName = "Padding";
 constexpr auto kAvgPoolOpName = "AvgPool";
 constexpr auto kAvgPoolGradGpuOpName = "AvgPoolGradGpu";
+constexpr auto kMaxPoolGradOpName = "MaxPoolGrad";
 constexpr auto kTensorAddOpName = "TensorAdd";
 constexpr auto kCastOpName = "Cast";
 constexpr auto kGreaterEqualOpName = "GreaterEqual";
diff --git a/tests/st/ops/gpu/test_cudnn_inplace_fusion.py b/tests/st/ops/gpu/test_cudnn_inplace_fusion.py
new file mode 100644
index 0000000000..d6fd28f9d3
--- /dev/null
+++ b/tests/st/ops/gpu/test_cudnn_inplace_fusion.py
@@ -0,0 +1,75 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore.common.parameter import Parameter
+from mindspore.common.initializer import initializer
+from mindspore import Tensor
+from mindspore.ops import operations as P
+from mindspore.ops.operations import _grad_ops as G
+
+
+class Conv2dBpropInputInplace(nn.Cell):
+    def __init__(self, w1, w2):
+        super(Conv2dBpropInputInplace, self).__init__()
+        self.conv2d_1 = P.Conv2DBackpropInput(out_channel=256, kernel_size=1)
+        self.w1 = Parameter(initializer(w1, w1.shape), name='w1')
+        self.conv2d_2 = P.Conv2DBackpropInput(out_channel=256, kernel_size=1)
+        self.w2 = Parameter(initializer(w2, w2.shape), name='w2')
+        self.add = P.TensorAdd()
+        self.maxpool = P.MaxPool(ksize=3, strides=2, padding='SAME')
+        self.maxpool_grad = G.MaxPoolGrad(ksize=3, strides=2, padding='SAME')
+        self.shape = (32, 64, 56, 56)
+
+    def construct(self, x1, x2, x3):
+        dx1 = self.conv2d_1(x1, self.w1, self.shape)
+        dx2 = self.conv2d_2(x2, self.w2, self.shape)
+
+        dx = self.add(dx1, dx2)
+        y = self.maxpool(x3)
+        y = self.maxpool_grad(x3, y, dx)
+        return y
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_inplace_fusion1():
+
+    np.random.seed(42)
+    w1_np = np.random.randn(64, 64, 1, 1)
+    w2_np = np.random.randn(256, 64, 1, 1)
+    x1_np = np.random.randn(32, 64, 56, 56)
+    x2_np = np.random.randn(32, 256, 56, 56)
+    x3_np = np.random.randn(32, 64, 112, 112)
+
+    w1 = Tensor(w1_np.astype(np.float32))
+    w2 = Tensor(w2_np.astype(np.float32))
+    x1 = Tensor(x1_np.astype(np.float32))
+    x2 = Tensor(x2_np.astype(np.float32))
+    x3 = Tensor(x3_np.astype(np.float32))
+
+    net = Conv2dBpropInputInplace(w1, w2)
+    context.set_context(device_target='GPU', mode=context.GRAPH_MODE)
+    fusion_output = net(x1, x2, x3)
+
+    context.set_context(device_target='GPU', mode=context.PYNATIVE_MODE)
+    no_fusion_output = net(x1, x2, x3)
+
+    assert np.allclose(fusion_output.asnumpy(), no_fusion_output.asnumpy(), atol=2e-5)