@@ -77,6 +77,15 @@ class GpuKernel : public KernelMod {
    }
    return GetValue<T>(attr);
  }
+  template <typename T>
+  inline T GetAttrWithDefault(const CNodePtr &kernel_node, const std::string &key, const T &value) const {
+    const PrimitivePtr &prim = AnfAlgo::GetCNodePrimitive(kernel_node);
+    const ValuePtr &attr = prim->GetAttr(key);
+    if (attr == nullptr) {
+      return value;
+    }
+    return GetValue<T>(attr);
+  }
  // expand Nd Shape to 4d (N in [0,4])
  void ShapeNdTo4d(const std::vector<size_t> &src, std::vector<int> *dst) {
    if (src.size() > 4) {
@@ -54,7 +54,8 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
        output_size_(0),
        padded_size_(0),
        workspace_size_(0),
-       use_pad_(true) {}
+       use_pad_(true),
+       beta_(0) {}
  ~ConvGradInputGpuBkwKernel() override { DestroyResource(); }
  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -75,13 +76,12 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
    }
    const float alpha = 1;
-   const float beta = 0;
    if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) {
      T *padded = GetDeviceAddress<T>(workspace, 1);
      CHECK_CUDNN_RET_WITH_EXCEPT(
        cudnnConvolutionBackwardData(cudnn_handle_, &alpha, w_desc_, w, dy_desc_, dy, conv_desc_, algo_, work_space,
-                                    workspace_size_, &beta, padded_descriptor_, padded),
+                                    workspace_size_, &beta_, padded_descriptor_, padded),
        "ConvolutionBackwardData failed");
      if (data_format_ == "NHWC") {
        CalPadGradNHWC(output_size_ / sizeof(T), padded, n_, old_height_, old_width_, c_, old_height_ + pad_height_,
@@ -93,7 +93,7 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
    } else {
      CHECK_CUDNN_RET_WITH_EXCEPT(
        cudnnConvolutionBackwardData(cudnn_handle_, &alpha, w_desc_, w, dy_desc_, dy, conv_desc_, algo_, work_space,
-                                    workspace_size_, &beta, dx_desc_, dx),
+                                    workspace_size_, &beta_, dx_desc_, dx),
        "ConvolutionBackwardData failed");
    }
    return true;
@@ -188,6 +188,7 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
                                  "cudnnSetConvolutionMathType failed.")
    }
    SelectAlgorithm(dx_desc_real);
+   beta_ = GetAttrWithDefault(kernel_node, "inplace_algo", std::string("cover")) == "cover" ? 0 : 1;
    InitSizeLists();
    return true;
  }
@@ -349,6 +350,7 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
  size_t padded_size_;
  size_t workspace_size_;
  bool use_pad_;
+ float beta_;
};
}  // namespace kernel
}  // namespace mindspore
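Background note (not part of the patch): cuDNN applies its scaling factors as dst = alpha * result + beta * dst, so the beta_ derived from the "inplace_algo" attribute decides whether the backward-data kernel overwrites its output buffer ("cover", beta_ = 0) or accumulates into whatever is already there ("accumulation", beta_ = 1). A minimal sketch of that idea, with a hypothetical bool accumulate standing in for the attribute lookup:

  const float alpha = 1;
  const float beta = accumulate ? 1.0f : 0.0f;  // "accumulation" adds into dx, "cover" overwrites it
  CHECK_CUDNN_RET_WITH_EXCEPT(
    cudnnConvolutionBackwardData(cudnn_handle_, &alpha, w_desc_, w, dy_desc_, dy, conv_desc_, algo_, work_space,
                                 workspace_size_, &beta, dx_desc_, dx),
    "ConvolutionBackwardData failed");

When two producers write the same dx buffer, the first runs with beta = 0 and the second with beta = 1, which folds the explicit TensorAdd into the convolution calls themselves.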
@@ -46,7 +46,8 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel {
        scale_bias_diff_desc_(nullptr),
        activation_desc_(nullptr),
        handle_(nullptr),
-       cudnn_data_type_(CUDNN_DATA_FLOAT) {}
+       cudnn_data_type_(CUDNN_DATA_FLOAT),
+       beta_data_diff_(0) {}
  ~FusedBatchNormGradExGpuKernel() override { DestroyResource(); }
  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -88,12 +89,10 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel {
    }
    const float alpha_data_diff = 1;
-   const float beta_data_diff = 0;
    const float alpha_param_diff = 1;
    const float beta_param_diff = 0;
    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnBatchNormalizationBackwardEx(
-                                 handle_, mode_, bn_ops_, &alpha_data_diff, &beta_data_diff, &alpha_param_diff,
+                                 handle_, mode_, bn_ops_, &alpha_data_diff, &beta_data_diff_, &alpha_param_diff,
                                  &beta_param_diff, x_desc_, x, y_desc_, y, dy_desc_, dy, dz_desc_, dz, dx_desc_, dx,
                                  scale_bias_diff_desc_, scale, bias, dscale, dbias, epsilon_, save_mean, save_variance,
                                  activation_desc_, workspace_addr, workspace_size_, reserve_addr, reserve_size_),
@@ -141,6 +140,7 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel {
      return true;
    }
    std::string format = AnfAlgo::GetInputFormat(kernel_node, 0);
+   beta_data_diff_ = GetAttrWithDefault(kernel_node, "inplace_algo", std::string("cover")) == "cover" ? 0 : 1;
    SetTensorDescriptor(format, shape);
    InitSizeLists();
    return true;
@@ -285,6 +285,7 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel {
  cudnnHandle_t handle_;
  cudnnDataType_t cudnn_data_type_;
+ float beta_data_diff_;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
@@ -0,0 +1,195 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/optimizer/gpu/cudnn_inplace_fusion.h"

#include <memory>
#include <vector>
#include <set>
#include <map>
#include <algorithm>
#include <utility>
#include <string>

#include "backend/session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
#include "utils/contract.h"
#include "backend/optimizer/common/helper.h"
#include "runtime/device/gpu/kernel_info_setter.h"

namespace mindspore {
namespace opt {
namespace {
struct AnfNodeIndex {
  AnfNodeIndex() : node(nullptr), index(0) {}
  AnfNodeIndex(const AnfNodePtr &n, const int i) : node(n), index(i) {}
  AnfNodePtr node;
  uint32_t index;
};

// opname, output idx
std::map<string, uint32_t> kInplaceOpNames = {{kConv2DBackpropInputOpName, 0},
                                              {kFusedBatchNormGradExWithAddAndActivation, 3}};

std::set<string> kSkipOpNames = {
  kTensorAddOpName,
};

// opname, input idx
std::map<string, uint32_t> kAggregatesOpNames = {
  {kConv2DBackpropInputOpName, 0}, {kmaxPoolGradOpName, 2}, {kFusedBatchNormGradExWithAddAndActivation, 0}};

template <typename T>
void SetPrimAttr(AnfNodePtr inplace_node, const string &key, const T &value) {
  auto primitive = AnfAlgo::GetCNodePrimitive(inplace_node);
  MS_EXCEPTION_IF_NULL(primitive);
  primitive->AddAttr(key, MakeValue(value));
}

void SetNodeAttr(AnfNodeIndex aggregate_node, AnfNodePtr skip_node, std::vector<AnfNodeIndex> *inplace_node) {
  SetPrimAttr(aggregate_node.node, "aggregate", true);
  SetPrimAttr(aggregate_node.node, "aggregate_input_index", aggregate_node.index);
  SetPrimAttr(skip_node, "skip", true);

  static uint32_t group = 0;
  for (size_t i = 0; i < inplace_node->size(); i++) {
    auto algo = (i == 0) ? "cover" : "accumulation";
    SetPrimAttr((*inplace_node)[i].node, "inplace_algo", algo);
    SetPrimAttr((*inplace_node)[i].node, "inplace_group", group);
    SetPrimAttr((*inplace_node)[i].node, "inplace_output_index", (*inplace_node)[i].index);
  }
  group++;
}
void InsertControlDependToGraph(const FuncGraphPtr &graph, const std::vector<AnfNodeIndex> &inplace_nodes) {
  std::vector<AnfNodePtr> inputs1 = {NewValueNode(std::make_shared<Primitive>(prim::kPrimControlDepend->name())),
                                     inplace_nodes[0].node, inplace_nodes[1].node};
  auto control_depend_node = graph->NewCNode(inputs1);

  auto return_node = graph->get_return();
  MS_EXCEPTION_IF_NULL(return_node);
  // Mount the `depend` before make_tuple, otherwise the output of the graph would be `(tensor, )` rather than `tensor`.
  auto return_input = return_node->input(kFirstDataInputIndex)->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(return_input);
  std::vector<AnfNodePtr> inputs2 = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
                                     return_input->input(kFirstDataInputIndex), control_depend_node};
  auto depend_node = graph->NewCNode(inputs2);
  return_node->set_input(kFirstDataInputIndex, depend_node);
}

bool PatternMatch(const FuncGraphPtr &graph, const AnfNodePtr &node, AnfNodeIndex *aggregate, AnfNodePtr *skip_node,
                  std::vector<AnfNodeIndex> *inplace) {
  MS_EXCEPTION_IF_NULL(skip_node);
  MS_EXCEPTION_IF_NULL(aggregate);
  if (!node->isa<CNode>()) {
    return false;
  }
  auto aggregate_iter = kAggregatesOpNames.find(AnfAlgo::GetCNodeName(node));
  if (aggregate_iter == kAggregatesOpNames.end()) {
    return false;
  }
  aggregate->node = node;
  aggregate->index = aggregate_iter->second;

  *skip_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), aggregate_iter->second);
  if (*skip_node == nullptr || !(*skip_node)->isa<CNode>() ||
      kSkipOpNames.count(AnfAlgo::GetCNodeName(*skip_node)) == 0 ||
      GetRealNodeUsedList(graph, *skip_node)->size() >= 2) {
    return false;
  }

  auto cnode = (*skip_node)->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode); i++) {
    auto inplace_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(*skip_node), i);
    if (!inplace_node->isa<CNode>()) {
      return false;
    }
    // Check that the inplace nodes have no users other than the TensorAdd node.
    if (GetRealNodeUsedList(graph, inplace_node)->size() >= 2) {
      return false;
    }
    // Skip TupleGetItem nodes.
    if (AnfAlgo::GetCNodeName(inplace_node) == prim::kPrimTupleGetItem->name()) {
      inplace_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(inplace_node), 0);
    }
    auto inplace_iter = kInplaceOpNames.find(AnfAlgo::GetCNodeName(inplace_node));
    if (inplace_iter == kInplaceOpNames.end()) {
      return false;
    }
    inplace->push_back(AnfNodeIndex(inplace_node, inplace_iter->second));
  }

  return true;
}

std::map<AnfNodePtr, int> TopoIndex(const std::vector<AnfNodePtr> &node_list) {
  std::map<AnfNodePtr, int> topo_index;
  for (size_t i = 0; i < node_list.size(); i++) {
    topo_index.insert(make_pair(node_list[i], i));
  }
  return topo_index;
}
}  // namespace
bool CudnnInplaceAggregate::Run(const FuncGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);

  std::vector<AnfNodePtr> node_list = TopoSort(graph->get_return());
  auto topo_index = TopoIndex(node_list);

  for (auto node : node_list) {
    AnfNodeIndex aggregate_node;
    AnfNodePtr skip_node;
    std::vector<AnfNodeIndex> inplace_node;
    // 1. Pattern match.
    if (!PatternMatch(graph, node, &aggregate_node, &skip_node, &inplace_node)) {
      continue;
    }
    // 2. Keep the original topological order, in case there is a dependence between the inplace nodes.
    std::sort(inplace_node.begin(), inplace_node.end(), [&topo_index](const AnfNodeIndex &n1, const AnfNodeIndex &n2) {
      auto iter1 = topo_index.find(n1.node);
      auto iter2 = topo_index.find(n2.node);
      if (iter1 == topo_index.end() || iter2 == topo_index.end()) {
        MS_LOG(EXCEPTION) << ": Node does not exist in topo order. node1: " << n1.node->DebugString()
                          << ", node2: " << n2.node->DebugString();
      }
      if (iter1->second < iter2->second) {
        return true;
      }
      return false;
    });
    MS_LOG(INFO) << "[inplace optimizer] aggregate node: " << aggregate_node.index << ", "
                 << aggregate_node.node->DebugString() << "; skip node: " << skip_node->DebugString() << std::endl
                 << "; inplace node 0: " << inplace_node[0].index << ", " << inplace_node[0].node->DebugString()
                 << std::endl
                 << "; inplace node 1: " << inplace_node[1].index << ", " << inplace_node[1].node->DebugString()
                 << std::endl;
    // 3. Set node attributes.
    SetNodeAttr(aggregate_node, skip_node, &inplace_node);
    // 4. Set the dependence between inplace nodes.
    InsertControlDependToGraph(graph, inplace_node);
  }

  return true;
}
}  // namespace opt
}  // namespace mindspore
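For orientation, the subgraph this pass rewrites (per kInplaceOpNames, kSkipOpNames and kAggregatesOpNames above) is roughly:

  inplace op A (e.g. Conv2DBackpropInput) ──┐  inplace_algo = "cover"
                                            ├── TensorAdd ("skip") ── aggregate op (Conv2DBackpropInput / MaxPoolGrad / FusedBatchNormGradEx...)
  inplace op B (e.g. Conv2DBackpropInput) ──┘  inplace_algo = "accumulation"

Both producers end up writing one shared output buffer (the first overwrites it, the second accumulates into it), the TensorAdd is tagged "skip" so the runtime never launches it, and its consumer is tagged "aggregate" so it reads the shared buffer directly; the inserted ControlDepend keeps the "cover" kernel ahead of the "accumulation" kernel.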
@@ -0,0 +1,32 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_CUDNN_INPLACE_AGGREGATE_FUSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_CUDNN_INPLACE_AGGREGATE_FUSION_H_

#include <memory>
#include "backend/optimizer/common/optimizer.h"

namespace mindspore {
namespace opt {
class CudnnInplaceAggregate : public Pass {
 public:
  CudnnInplaceAggregate() : Pass("cudnn_inplace_aggregate") {}
  ~CudnnInplaceAggregate() override = default;
  bool Run(const FuncGraphPtr &g) override;
};
}  // namespace opt
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_CUDNN_INPLACE_AGGREGATE_FUSION_H_
@@ -18,6 +18,7 @@
#include <algorithm>
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/common/helper.h"
+#include "runtime/device/kernel_runtime_manager.h"

namespace mindspore {
namespace device {
@@ -191,6 +192,33 @@ bool MemSwapManager::IsCommunicationRelevantOp(const AnfNodePtr &kernel) const {
  return adjacent_with_communication_op;
}

+bool MemSwapManager::IsInplaceRelevantOp(const TensorInfo &tensor) {
+  MS_EXCEPTION_IF_NULL(tensor.kernel_);
+  if (AnfAlgo::IsInplaceNode(tensor.kernel_, "inplace_algo") || AnfAlgo::IsInplaceNode(tensor.kernel_, "skip")) {
+    return true;
+  }
+
+  MS_EXCEPTION_IF_NULL(kernel_graph_);
+  const auto &graph_manager = kernel_graph_->manager();
+  MS_EXCEPTION_IF_NULL(graph_manager);
+  NodeUsersMap &user_map = graph_manager->node_users();
+  auto users = user_map.find(tensor.kernel_);
+  for (const auto &user : users->second) {
+    if (!AnfAlgo::IsInplaceNode(user.first, "aggregate")) {
+      continue;
+    }
+    auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(user.first, user.second);
+    if (tensor.output_idx_ == kernel_with_index.second) {
+      MS_LOG(INFO) << "[inplace optimizer] tensor: " << tensor.kernel_->DebugString()
+                   << ", output idx: " << tensor.output_idx_
+                   << " used by aggregate node: " << user.first->DebugString();
+      return true;
+    }
+  }
+  return false;
+}
+
void MemSwapManager::SaveUserKernelTopoOrder() {
  MS_EXCEPTION_IF_NULL(kernel_graph_);
  const auto &graph_manager = kernel_graph_->manager();
@@ -231,7 +259,10 @@ void MemSwapManager::AddSwapInfo() {
    }

    const AnfNodePtr &kernel = tensor.kernel_;
-   if (IsCommunicationRelevantOp(kernel)) {
+   bool filter = IsCommunicationRelevantOp(kernel) || IsInplaceRelevantOp(tensor);
+   if (filter) {
+     MS_LOG(INFO) << "[inplace optimizer] ignore swap tensor: " << kernel->DebugString()
+                  << ", index: " << tensor.output_idx_;
      continue;
    }
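The extra filter keeps inplace-tagged tensors out of the memory-swap candidate set; since the members of an inplace group and the aggregate consumer's marked input all alias one shared device buffer, swapping any of them to host and back would invalidate the in-place aggregation (and the skipped TensorAdd never runs to recreate the data).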
@@ -136,6 +136,8 @@ class MemSwapManager {
  bool IsCommunicationRelevantOp(const AnfNodePtr &kernel) const;

+ bool IsInplaceRelevantOp(const TensorInfo &tensor);
+
  std::vector<CNodePtr> execution_order_;
  std::vector<TensorInfo> ordered_tensors_;
  std::unordered_map<void *, KernelExecutionInfo> kernel_execution_info_;
@@ -1031,6 +1031,21 @@ void AnfRuntimeAlgorithm::SetNodeInput(const CNodePtr &node, const AnfNodePtr &i
  node->set_input(index + 1, input_node);
}

+bool AnfRuntimeAlgorithm::IsInplaceNode(const mindspore::AnfNodePtr &kernel, const string &type) {
+  MS_EXCEPTION_IF_NULL(kernel);
+  auto primitive = AnfAlgo::GetCNodePrimitive(kernel);
+  if (!primitive) {
+    return false;
+  }
+
+  auto inplace_attr = primitive->GetAttr(type);
+  if (inplace_attr == nullptr) {
+    return false;
+  }
+
+  return true;
+}
+
bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
@@ -206,6 +206,7 @@ class AnfRuntimeAlgorithm {
  // get real input index for some tbe ops which input order is different between me and tbe impl
  static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index);
  static bool IsCommunicationOp(const AnfNodePtr &node);
+ static bool IsInplaceNode(const AnfNodePtr &node, const string &type);
  static bool IsGetNext(const NotNull<AnfNodePtr> &node);
  static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node);
  static std::vector<KernelGraphPtr> GetCallSwitchKernelGraph(const CNodePtr &cnode);
@@ -39,6 +39,7 @@
#include "backend/optimizer/gpu/insert_format_transform_op.h"
#include "backend/optimizer/gpu/remove_format_transform_pair.h"
#include "backend/optimizer/gpu/remove_redundant_format_transform.h"
+#include "backend/optimizer/gpu/cudnn_inplace_fusion.h"
#include "backend/optimizer/graph_kernel/graph_kernel_splitter.h"
#include "backend/optimizer/graph_kernel/graph_kernel_expander.h"
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
@@ -15,6 +15,7 @@
 */
#include "runtime/device/gpu/gpu_kernel_runtime.h"
#include <algorithm>
+#include <map>
#include "pybind11/pybind11.h"
#include "runtime/device/gpu/gpu_device_address.h"
#include "runtime/device/gpu/cuda_driver.h"
@@ -279,6 +280,51 @@ void GPUKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std::v
  ClearOutputAddress(inputs, value_nodes, execution_order);
}

+void GPUKernelRuntime::AllocInplaceNodeMemory(const session::KernelGraph *graph) {
+  std::map<uint32_t, std::vector<CNodePtr>> inplace_groups;
+  auto kernel_cnodes = graph->execution_order();
+  for (auto &kernel : kernel_cnodes) {
+    if (!AnfAlgo::IsInplaceNode(kernel, "inplace_algo")) {
+      continue;
+    }
+    auto primitive = AnfAlgo::GetCNodePrimitive(kernel);
+    auto group_attr = primitive->GetAttr("inplace_group");
+    MS_EXCEPTION_IF_NULL(group_attr);
+    auto group_id = GetValue<uint32_t>(group_attr);
+    inplace_groups[group_id].push_back(kernel);
+  }
+
+  for (auto &group : inplace_groups) {
+    auto &item = group.second;
+    // in-place compute when group size >= 2.
+    if (item.size() < 2) {
+      continue;
+    }
+
+    auto primitive = AnfAlgo::GetCNodePrimitive(item[0]);
+    auto output_index = GetValue<uint32_t>(primitive->GetAttr("inplace_output_index"));
+    auto device_address = AnfAlgo::GetMutableOutputAddr(item[0], output_index, false);
+    if (device_address->GetPtr() != nullptr) {
+      continue;
+    }
+
+    auto kernel_mod = AnfAlgo::GetKernelMod(item[0]);
+    auto output_size = kernel_mod->GetOutputSizeList();
+    auto ret = mem_manager_->MallocMemFromMemPool(device_address, output_size[output_index]);
+    if (!ret) {
+      MS_LOG(EXCEPTION) << "Cannot alloc address, tensor size is: " << output_size[output_index];
+    }
+
+    for (auto &node : item) {
+      auto prim = AnfAlgo::GetCNodePrimitive(node);
+      auto index = GetValue<uint32_t>(prim->GetAttr("inplace_output_index"));
+      AnfAlgo::SetOutputAddr(device_address, index, node.get());
+      MS_LOG(INFO) << "[inplace optimizer] group id: " << group.first << ", node: " << node->DebugString()
+                   << ", output_index: " << index;
+    }
+  }
+}
+
void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
@@ -545,6 +591,7 @@ bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph, De
  mem_reuse_util_->ResetDynamicUsedRefCount();
  // The inputs and outputs memory of communication kernel need be continuous, so separate processing.
  AllocCommunicationOpDynamicRes(graph);
+ AllocInplaceNodeMemory(graph);

#ifdef ENABLE_DEBUGGER
  debugger_ = debugger;
@@ -562,6 +609,10 @@ bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph, De
  for (const auto &kernel : kernels) {
    auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
    MS_EXCEPTION_IF_NULL(kernel_mod);
+   if (AnfAlgo::IsInplaceNode(kernel, "skip")) {
+     MS_LOG(INFO) << "[inplace optimizer] skip node: " << kernel->DebugString();
+     continue;
+   }
    AddressPtrList kernel_inputs;
    AddressPtrList kernel_workspaces;
    AddressPtrList kernel_outputs;
@@ -808,6 +859,19 @@ bool GPUKernelRuntime::AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &k
      // Graph may be "nop node + depend + node", the input of node is the depend, so this case need skip nop node.
      device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, true);
    }
+   // Get in-place output_address
+   if (AnfAlgo::IsInplaceNode(kernel, "aggregate")) {
+     auto primitive = AnfAlgo::GetCNodePrimitive(kernel);
+     auto input_index = GetValue<uint32_t>(primitive->GetAttr("aggregate_input_index"));
+     if (i == input_index) {
+       auto skip_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(kernel), input_index);
+       device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(skip_node, 0, false);
+       MS_LOG(INFO) << "[inplace optimizer] aggregate: " << kernel->DebugString()
+                    << ", skip: " << skip_node->DebugString() << ", address: " << device_address->GetMutablePtr();
+     }
+   }
+
    MS_EXCEPTION_IF_NULL(device_address);
    UpdateHostSwapInQueue(device_address, mock);
    MS_EXCEPTION_IF_NULL(device_address->ptr_);
@@ -969,6 +1033,14 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel)
  }
  // Free the input of kernel by reference count.
  for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
+   if (AnfAlgo::IsInplaceNode(kernel, "aggregate")) {
+     auto primitive = AnfAlgo::GetCNodePrimitive(kernel);
+     auto index = GetValue<uint32_t>(primitive->GetAttr("aggregate_input_index"));
+     if (i == index) {
+       continue;
+     }
+   }
+
    auto kernel_ref_count_ptr = mem_reuse_util_->GetKernelInputRef(cnode, i);
    if (kernel_ref_count_ptr == nullptr) {
      continue;
    }
@@ -94,6 +94,7 @@ class GPUKernelRuntime : public KernelRuntime {
  void UpdateHostSwapInQueue(const DeviceAddressPtr device_address, bool mock);
  void UpdateHostSwapOutQueue(bool mock);
  void ClearSwapInfo(bool mock);
+ void AllocInplaceNodeMemory(const session::KernelGraph *graph);
  std::unordered_map<uint32_t, MemReuseUtilPtr> mem_reuse_util_map_;
  std::unordered_map<uint32_t, MemSwapManagerPtr> mem_swap_map_;
  std::unordered_map<uint32_t, bool> is_first_step_map_;
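Taken together, the runtime changes consume the contract set up by the fusion pass: AllocInplaceNodeMemory allocates one device buffer per "inplace_group" and binds it as the output address of every member, LaunchKernelDynamic skips kernels tagged "skip" (the now-redundant TensorAdd), AllocKernelInputDynamicRes redirects the "aggregate" consumer's marked input to the buffer behind the skipped node, and FreeKernelDynamicRes leaves that input's reference count alone so the shared buffer is not released through the skipped node's bookkeeping.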
@@ -193,6 +193,7 @@ constexpr auto kEmbeddingLookupProxyOpName = "EmbeddingLookupProxy";
constexpr auto kPaddingOpName = "Padding";
constexpr auto kAvgPoolOpName = "AvgPool";
constexpr auto kAvgPoolGradGpuOpName = "AvgPoolGradGpu";
+constexpr auto kmaxPoolGradOpName = "MaxPoolGrad";
constexpr auto kTensorAddOpName = "TensorAdd";
constexpr auto kCastOpName = "Cast";
constexpr auto kGreaterEqualOpName = "GreaterEqual";
@@ -0,0 +1,75 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G


class Conv2dBpropInputInplace(nn.Cell):
    def __init__(self, w1, w2):
        super(Conv2dBpropInputInplace, self).__init__()
        self.conv2d_1 = P.Conv2DBackpropInput(out_channel=256, kernel_size=1)
        self.w1 = Parameter(initializer(w1, w1.shape), name='w1')
        self.conv2d_2 = P.Conv2DBackpropInput(out_channel=256, kernel_size=1)
        self.w2 = Parameter(initializer(w2, w2.shape), name='w2')
        self.add = P.TensorAdd()
        self.maxpool = P.MaxPool(ksize=3, strides=2, padding='SAME')
        self.maxpool_grad = G.MaxPoolGrad(ksize=3, strides=2, padding='SAME')
        self.shape = (32, 64, 56, 56)

    def construct(self, x1, x2, x3):
        dx1 = self.conv2d_1(x1, self.w1, self.shape)
        dx2 = self.conv2d_2(x2, self.w2, self.shape)
        dx = self.add(dx1, dx2)
        y = self.maxpool(x3)
        y = self.maxpool_grad(x3, y, dx)
        return y


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_inplace_fusion1():
    np.random.seed(42)
    w1_np = np.random.randn(64, 64, 1, 1)
    w2_np = np.random.randn(256, 64, 1, 1)
    x1_np = np.random.randn(32, 64, 56, 56)
    x2_np = np.random.randn(32, 256, 56, 56)
    x3_np = np.random.randn(32, 64, 112, 112)
    w1 = Tensor(w1_np.astype(np.float32))
    w2 = Tensor(w2_np.astype(np.float32))
    x1 = Tensor(x1_np.astype(np.float32))
    x2 = Tensor(x2_np.astype(np.float32))
    x3 = Tensor(x3_np.astype(np.float32))
    net = Conv2dBpropInputInplace(w1, w2)

    context.set_context(device_target='GPU', mode=context.GRAPH_MODE)
    fusion_output = net(x1, x2, x3)

    context.set_context(device_target='GPU', mode=context.PYNATIVE_MODE)
    no_fusion_output = net(x1, x2, x3)

    assert np.allclose(fusion_output.asnumpy(), no_fusion_output.asnumpy(), atol=2e-5)
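The test exercises the fusable pattern end to end: the graph-mode run goes through the backend optimizer (including the new cudnn_inplace_aggregate pass), while the PyNative run executes the same cell op by op and serves as the unfused reference, so the np.allclose check verifies that folding the TensorAdd into the in-place convolutions does not change the numerical result within atol=2e-5.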