@@ -77,6 +77,15 @@ class GpuKernel : public KernelMod {
    }
    return GetValue<T>(attr);
  }
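  // Like GetAttr, but returns the caller-supplied default instead of raising when the attribute is absent.
  // This lets optional attrs such as "inplace_algo" (set only by the cudnn inplace fusion pass) be read safely.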
  template <typename T>
  inline T GetAttrWithDefault(const CNodePtr &kernel_node, const std::string &key, const T &value) const {
    const PrimitivePtr &prim = AnfAlgo::GetCNodePrimitive(kernel_node);
    const ValuePtr &attr = prim->GetAttr(key);
    if (attr == nullptr) {
      return value;
    }
    return GetValue<T>(attr);
  }
  // expand Nd Shape to 4d (N in [0,4])
  void ShapeNdTo4d(const std::vector<size_t> &src, std::vector<int> *dst) {
    if (src.size() > 4) {
@@ -54,7 +54,8 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
        output_size_(0),
        padded_size_(0),
        workspace_size_(0),
        use_pad_(true) {}
        use_pad_(true),
        beta_(0) {}
  ~ConvGradInputGpuBkwKernel() override { DestroyResource(); }
  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -75,13 +76,12 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
    }
    const float alpha = 1;
    const float beta = 0;
    if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) {
      T *padded = GetDeviceAddress<T>(workspace, 1);
      CHECK_CUDNN_RET_WITH_EXCEPT(
        cudnnConvolutionBackwardData(cudnn_handle_, &alpha, w_desc_, w, dy_desc_, dy, conv_desc_, algo_, work_space,
                                     workspace_size_, &beta, padded_descriptor_, padded),
                                     workspace_size_, &beta_, padded_descriptor_, padded),
        "ConvolutionBackwardData failed");
      if (data_format_ == "NHWC") {
        CalPadGradNHWC(output_size_ / sizeof(T), padded, n_, old_height_, old_width_, c_, old_height_ + pad_height_,
@@ -93,7 +93,7 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
    } else {
      CHECK_CUDNN_RET_WITH_EXCEPT(
        cudnnConvolutionBackwardData(cudnn_handle_, &alpha, w_desc_, w, dy_desc_, dy, conv_desc_, algo_, work_space,
                                     workspace_size_, &beta, dx_desc_, dx),
                                     workspace_size_, &beta_, dx_desc_, dx),
        "ConvolutionBackwardData failed");
    }
    return true;
@@ -188,6 +188,7 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
                                "cudnnSetConvolutionMathType failed.")
    }
    SelectAlgorithm(dx_desc_real);
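    // cudnn's beta scaling factor selects how the result is written: beta = 0 overwrites dx, beta = 1 accumulates
    // into the existing dx buffer. "inplace_algo" is set by the cudnn inplace fusion pass and defaults to "cover".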
    beta_ = GetAttrWithDefault(kernel_node, "inplace_algo", std::string("cover")) == "cover" ? 0 : 1;
    InitSizeLists();
    return true;
  }
@@ -349,6 +350,7 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
  size_t padded_size_;
  size_t workspace_size_;
  bool use_pad_;
  float beta_;
};
} // namespace kernel
} // namespace mindspore
@@ -46,7 +46,8 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel {
        scale_bias_diff_desc_(nullptr),
        activation_desc_(nullptr),
        handle_(nullptr),
        cudnn_data_type_(CUDNN_DATA_FLOAT) {}
        cudnn_data_type_(CUDNN_DATA_FLOAT),
        beta_data_diff_(0) {}
  ~FusedBatchNormGradExGpuKernel() override { DestroyResource(); }
  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -88,12 +89,10 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel {
    }
    const float alpha_data_diff = 1;
    const float beta_data_diff = 0;
    const float alpha_param_diff = 1;
    const float beta_param_diff = 0;
    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnBatchNormalizationBackwardEx(
                                  handle_, mode_, bn_ops_, &alpha_data_diff, &beta_data_diff, &alpha_param_diff,
                                  handle_, mode_, bn_ops_, &alpha_data_diff, &beta_data_diff_, &alpha_param_diff,
                                  &beta_param_diff, x_desc_, x, y_desc_, y, dy_desc_, dy, dz_desc_, dz, dx_desc_, dx,
                                  scale_bias_diff_desc_, scale, bias, dscale, dbias, epsilon_, save_mean, save_variance,
                                  activation_desc_, workspace_addr, workspace_size_, reserve_addr, reserve_size_),
@@ -141,6 +140,7 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel {
      return true;
    }
    std::string format = AnfAlgo::GetInputFormat(kernel_node, 0);
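    // Same cudnn convention as above: beta_data_diff_ = 1 makes cudnnBatchNormalizationBackwardEx accumulate dx
    // into the buffer shared with the other inplace node of the group instead of overwriting it.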
    beta_data_diff_ = GetAttrWithDefault(kernel_node, "inplace_algo", std::string("cover")) == "cover" ? 0 : 1;
    SetTensorDescriptor(format, shape);
    InitSizeLists();
    return true;
@@ -285,6 +285,7 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel {
  cudnnHandle_t handle_;
  cudnnDataType_t cudnn_data_type_;
  float beta_data_diff_;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
@@ -0,0 +1,195 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/optimizer/gpu/cudnn_inplace_fusion.h"
#include <memory>
#include <vector>
#include <set>
#include <map>
#include <algorithm>
#include <utility>
#include <string>
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
#include "utils/contract.h"
#include "backend/optimizer/common/helper.h"
#include "runtime/device/gpu/kernel_info_setter.h"
namespace mindspore {
namespace opt {
namespace {
struct AnfNodeIndex {
  AnfNodeIndex() : node(nullptr), index(0) {}
  AnfNodeIndex(const AnfNodePtr &n, const int i) : node(n), index(i) {}
  AnfNodePtr node;
  uint32_t index;
};
// opname, output idx
std::map<string, uint32_t> kInplaceOpNames = {{kConv2DBackpropInputOpName, 0},
                                              {kFusedBatchNormGradExWithAddAndActivation, 3}};
std::set<string> kSkipOpNames = {
  kTensorAddOpName,
};
// opname, input idx
std::map<string, uint32_t> kAggregatesOpNames = {
  {kConv2DBackpropInputOpName, 0}, {kmaxPoolGradOpName, 2}, {kFusedBatchNormGradExWithAddAndActivation, 0}};
template <typename T>
void SetPrimAttr(AnfNodePtr inplace_node, const string &key, const T &value) {
  auto primitive = AnfAlgo::GetCNodePrimitive(inplace_node);
  MS_EXCEPTION_IF_NULL(primitive);
  primitive->AddAttr(key, MakeValue(value));
}
void SetNodeAttr(AnfNodeIndex aggregate_node, AnfNodePtr skip_node, std::vector<AnfNodeIndex> *inplace_node) {
  SetPrimAttr(aggregate_node.node, "aggregate", true);
  SetPrimAttr(aggregate_node.node, "aggregate_input_index", aggregate_node.index);
  SetPrimAttr(skip_node, "skip", true);
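  // Every inplace node matched by one pattern gets the same inplace_group id. The first node in the group
  // "cover"s the shared output buffer (cudnn beta = 0); the remaining nodes use "accumulation" (beta = 1).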
  static uint32_t group = 0;
  for (size_t i = 0; i < inplace_node->size(); i++) {
    auto algo = (i == 0) ? "cover" : "accumulation";
    SetPrimAttr((*inplace_node)[i].node, "inplace_algo", algo);
    SetPrimAttr((*inplace_node)[i].node, "inplace_group", group);
    SetPrimAttr((*inplace_node)[i].node, "inplace_output_index", (*inplace_node)[i].index);
  }
  group++;
}
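// The accumulating node must run after the covering node, since both write the same device buffer; a ControlDepend
// between the two inplace nodes is attached in front of the graph output to enforce that order.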
void InsertControlDependToGraph(const FuncGraphPtr &graph, const std::vector<AnfNodeIndex> &inplace_nodes) {
  std::vector<AnfNodePtr> inputs1 = {NewValueNode(std::make_shared<Primitive>(prim::kPrimControlDepend->name())),
                                     inplace_nodes[0].node, inplace_nodes[1].node};
  auto control_depend_node = graph->NewCNode(inputs1);
  auto return_node = graph->get_return();
  MS_EXCEPTION_IF_NULL(return_node);
  // mount the `depend` before make_tuple, otherwise the output of graph will be `(tensor, )` rather than `tensor`
  auto return_input = return_node->input(kFirstDataInputIndex)->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(return_input);
  std::vector<AnfNodePtr> inputs2 = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
                                     return_input->input(kFirstDataInputIndex), control_depend_node};
  auto depend_node = graph->NewCNode(inputs2);
  return_node->set_input(kFirstDataInputIndex, depend_node);
}
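// Match the pattern: inplace-capable producers -> TensorAdd (the "skip" node) -> an aggregate consumer such as
// MaxPoolGrad. When it matches, the add can be folded away by letting the later producer accumulate into the
// first producer's output buffer.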
bool PatternMatch(const FuncGraphPtr &graph, const AnfNodePtr &node, AnfNodeIndex *aggregate, AnfNodePtr *skip_node,
                  std::vector<AnfNodeIndex> *inplace) {
  MS_EXCEPTION_IF_NULL(skip_node);
  MS_EXCEPTION_IF_NULL(aggregate);
  if (!node->isa<CNode>()) {
    return false;
  }
  auto aggregate_iter = kAggregatesOpNames.find(AnfAlgo::GetCNodeName(node));
  if (aggregate_iter == kAggregatesOpNames.end()) {
    return false;
  }
  aggregate->node = node;
  aggregate->index = aggregate_iter->second;
  *skip_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), aggregate_iter->second);
  if (*skip_node == nullptr || !(*skip_node)->isa<CNode>() ||
      kSkipOpNames.count(AnfAlgo::GetCNodeName(*skip_node)) == 0 ||
      GetRealNodeUsedList(graph, *skip_node)->size() >= 2) {
    return false;
  }
  auto cnode = (*skip_node)->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode); i++) {
    auto inplace_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(*skip_node), i);
    if (!inplace_node->isa<CNode>()) {
      return false;
    }
    // Check that the inplace nodes have no users other than the TensorAdd (skip) node
    if (GetRealNodeUsedList(graph, inplace_node)->size() >= 2) {
      return false;
    }
    // skip TupleGetItem node
    if (AnfAlgo::GetCNodeName(inplace_node) == prim::kPrimTupleGetItem->name()) {
      inplace_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(inplace_node), 0);
    }
    auto inplace_iter = kInplaceOpNames.find(AnfAlgo::GetCNodeName(inplace_node));
    if (inplace_iter == kInplaceOpNames.end()) {
      return false;
    }
    inplace->push_back(AnfNodeIndex(inplace_node, inplace_iter->second));
  }
  return true;
}
std::map<AnfNodePtr, int> TopoIndex(const std::vector<AnfNodePtr> &node_list) {
  std::map<AnfNodePtr, int> topo_index;
  for (size_t i = 0; i < node_list.size(); i++) {
    topo_index.insert(make_pair(node_list[i], i));
  }
  return topo_index;
}
} // namespace
bool CudnnInplaceAggregate::Run(const FuncGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> node_list = TopoSort(graph->get_return());
  auto topo_index = TopoIndex(node_list);
  for (auto node : node_list) {
    AnfNodeIndex aggregate_node;
    AnfNodePtr skip_node;
    std::vector<AnfNodeIndex> inplace_node;
    // 1. Pattern Match.
    if (!PatternMatch(graph, node, &aggregate_node, &skip_node, &inplace_node)) {
      continue;
    }
    // 2. Keep the original topological order, in case there are dependences between the inplace nodes
    std::sort(inplace_node.begin(), inplace_node.end(), [&topo_index](const AnfNodeIndex &n1, const AnfNodeIndex &n2) {
      auto iter1 = topo_index.find(n1.node);
      auto iter2 = topo_index.find(n2.node);
      if (iter1 == topo_index.end() || iter2 == topo_index.end()) {
        MS_LOG(EXCEPTION) << "Node does not exist in topo order. node1: " << n1.node->DebugString()
                          << ", node2: " << n2.node->DebugString();
      }
      if (iter1->second < iter2->second) {
        return true;
      }
      return false;
    });
    MS_LOG(INFO) << "[inplace optimizer] aggregate node: " << aggregate_node.index << ", "
                 << aggregate_node.node->DebugString() << "; skip node: " << skip_node->DebugString() << std::endl
                 << "; inplace node 0: " << inplace_node[0].index << ", " << inplace_node[0].node->DebugString()
                 << std::endl
                 << "; inplace node 1: " << inplace_node[1].index << ", " << inplace_node[1].node->DebugString()
                 << std::endl;
    // 3. Set node attrs
    SetNodeAttr(aggregate_node, skip_node, &inplace_node);
    // 4. Set dependence between inplace nodes
    InsertControlDependToGraph(graph, inplace_node);
  }
  return true;
}
} // namespace opt
} // namespace mindspore
@@ -0,0 +1,32 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_CUDNN_INPLACE_AGGREGATE_FUSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_CUDNN_INPLACE_AGGREGATE_FUSION_H_
#include <memory>
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class CudnnInplaceAggregate : public Pass {
 public:
  CudnnInplaceAggregate() : Pass("cudnn_inplace_aggregate") {}
  ~CudnnInplaceAggregate() override = default;
  bool Run(const FuncGraphPtr &g) override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_CUDNN_INPLACE_AGGREGATE_FUSION_H_
@@ -18,6 +18,7 @@
#include <algorithm>
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/common/helper.h"
#include "runtime/device/kernel_runtime_manager.h"
namespace mindspore {
namespace device {
@@ -191,6 +192,33 @@ bool MemSwapManager::IsCommunicationRelevantOp(const AnfNodePtr &kernel) const {
  return adjacent_with_communication_op;
}
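// Tensors that take part in cudnn inplace fusion (the inplace producers, the skipped TensorAdd, and the aggregated
// input of the consumer) alias one device buffer, so the memory swap manager must not swap them out to host.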
bool MemSwapManager::IsInplaceRelevantOp(const TensorInfo &tensor) {
  MS_EXCEPTION_IF_NULL(tensor.kernel_);
  if (AnfAlgo::IsInplaceNode(tensor.kernel_, "inplace_algo") || AnfAlgo::IsInplaceNode(tensor.kernel_, "skip")) {
    return true;
  }
  MS_EXCEPTION_IF_NULL(kernel_graph_);
  const auto &graph_manager = kernel_graph_->manager();
  MS_EXCEPTION_IF_NULL(graph_manager);
  NodeUsersMap &user_map = graph_manager->node_users();
  auto users = user_map.find(tensor.kernel_);
  for (const auto &user : users->second) {
    if (!AnfAlgo::IsInplaceNode(user.first, "aggregate")) {
      continue;
    }
    auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(user.first, user.second);
    if (tensor.output_idx_ == kernel_with_index.second) {
      MS_LOG(INFO) << "[inplace optimizer] tensor: " << tensor.kernel_->DebugString()
                   << ", output idx: " << tensor.output_idx_ << " used by aggregate node: " << user.first->DebugString();
      return true;
    }
  }
  return false;
}
void MemSwapManager::SaveUserKernelTopoOrder() {
  MS_EXCEPTION_IF_NULL(kernel_graph_);
  const auto &graph_manager = kernel_graph_->manager();
@@ -231,7 +259,10 @@ void MemSwapManager::AddSwapInfo() {
    }
    const AnfNodePtr &kernel = tensor.kernel_;
    if (IsCommunicationRelevantOp(kernel)) {
    bool filter = IsCommunicationRelevantOp(kernel) || IsInplaceRelevantOp(tensor);
    if (filter) {
      MS_LOG(INFO) << "[inplace optimizer] ignore swap tensor: " << kernel->DebugString() << ", index: "
                   << tensor.output_idx_;
      continue;
    }
@@ -136,6 +136,8 @@ class MemSwapManager {
  bool IsCommunicationRelevantOp(const AnfNodePtr &kernel) const;
  bool IsInplaceRelevantOp(const TensorInfo &tensor);
  std::vector<CNodePtr> execution_order_;
  std::vector<TensorInfo> ordered_tensors_;
  std::unordered_map<void *, KernelExecutionInfo> kernel_execution_info_;
@@ -1031,6 +1031,21 @@ void AnfRuntimeAlgorithm::SetNodeInput(const CNodePtr &node, const AnfNodePtr &i
  node->set_input(index + 1, input_node);
}
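// A node counts as inplace-relevant when its primitive carries the given attribute ("inplace_algo", "aggregate"
// or "skip") attached by the cudnn inplace fusion pass.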
bool AnfRuntimeAlgorithm::IsInplaceNode(const mindspore::AnfNodePtr &kernel, const string &type) {
  MS_EXCEPTION_IF_NULL(kernel);
  auto primitive = AnfAlgo::GetCNodePrimitive(kernel);
  if (!primitive) {
    return false;
  }
  auto inplace_attr = primitive->GetAttr(type);
  if (inplace_attr == nullptr) {
    return false;
  }
  return true;
}
bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
@@ -206,6 +206,7 @@ class AnfRuntimeAlgorithm {
  // get real input index for some tbe ops which input order is different between me and tbe impl
  static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index);
  static bool IsCommunicationOp(const AnfNodePtr &node);
  static bool IsInplaceNode(const AnfNodePtr &node, const string &type);
  static bool IsGetNext(const NotNull<AnfNodePtr> &node);
  static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node);
  static std::vector<KernelGraphPtr> GetCallSwitchKernelGraph(const CNodePtr &cnode);
@@ -39,6 +39,7 @@
#include "backend/optimizer/gpu/insert_format_transform_op.h"
#include "backend/optimizer/gpu/remove_format_transform_pair.h"
#include "backend/optimizer/gpu/remove_redundant_format_transform.h"
#include "backend/optimizer/gpu/cudnn_inplace_fusion.h"
#include "backend/optimizer/graph_kernel/graph_kernel_splitter.h"
#include "backend/optimizer/graph_kernel/graph_kernel_expander.h"
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
@@ -15,6 +15,7 @@
 */
#include "runtime/device/gpu/gpu_kernel_runtime.h"
#include <algorithm>
#include <map>
#include "pybind11/pybind11.h"
#include "runtime/device/gpu/gpu_device_address.h"
#include "runtime/device/gpu/cuda_driver.h"
@@ -279,6 +280,51 @@ void GPUKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std::v
  ClearOutputAddress(inputs, value_nodes, execution_order);
}
void GPUKernelRuntime::AllocInplaceNodeMemory(const session::KernelGraph *graph) {
  std::map<uint32_t, std::vector<CNodePtr>> inplace_groups;
  auto kernel_cnodes = graph->execution_order();
  for (auto &kernel : kernel_cnodes) {
    if (!AnfAlgo::IsInplaceNode(kernel, "inplace_algo")) {
      continue;
    }
    auto primitive = AnfAlgo::GetCNodePrimitive(kernel);
    auto group_attr = primitive->GetAttr("inplace_group");
    MS_EXCEPTION_IF_NULL(group_attr);
    auto group_id = GetValue<uint32_t>(group_attr);
    inplace_groups[group_id].push_back(kernel);
  }
  for (auto &group : inplace_groups) {
    auto &item = group.second;
    // Only do the in-place rewrite when the group has at least two nodes sharing one buffer.
    if (item.size() < 2) {
      continue;
    }
    auto primitive = AnfAlgo::GetCNodePrimitive(item[0]);
    auto output_index = GetValue<uint32_t>(primitive->GetAttr("inplace_output_index"));
    auto device_address = AnfAlgo::GetMutableOutputAddr(item[0], output_index, false);
    if (device_address->GetPtr() != nullptr) {
      continue;
    }
    auto kernel_mod = AnfAlgo::GetKernelMod(item[0]);
    auto output_size = kernel_mod->GetOutputSizeList();
    auto ret = mem_manager_->MallocMemFromMemPool(device_address, output_size[output_index]);
    if (!ret) {
      MS_LOG(EXCEPTION) << "Cannot alloc address, tensor size is: " << output_size[output_index];
    }
    for (auto &node : item) {
      auto prim = AnfAlgo::GetCNodePrimitive(node);
      auto index = GetValue<uint32_t>(prim->GetAttr("inplace_output_index"));
      AnfAlgo::SetOutputAddr(device_address, index, node.get());
      MS_LOG(INFO) << "[inplace optimizer] group id: " << group.first << ", node: " << node->DebugString()
                   << ", output_index: " << index;
    }
  }
}
void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
@@ -545,6 +591,7 @@ bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph, De
  mem_reuse_util_->ResetDynamicUsedRefCount();
  // The input and output memory of communication kernels needs to be contiguous, so it is processed separately.
  AllocCommunicationOpDynamicRes(graph);
  AllocInplaceNodeMemory(graph);
#ifdef ENABLE_DEBUGGER
  debugger_ = debugger;
@@ -562,6 +609,10 @@ bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph, De
  for (const auto &kernel : kernels) {
    auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
    MS_EXCEPTION_IF_NULL(kernel_mod);
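    // The TensorAdd marked "skip" is already folded into the inplace accumulation, so its launch is omitted.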
| if (AnfAlgo::IsInplaceNode(kernel, "skip")) { | |||
| MS_LOG(INFO) << "[inplace optimizer] skip node: " << kernel->DebugString(); | |||
| continue; | |||
| } | |||
| AddressPtrList kernel_inputs; | |||
| AddressPtrList kernel_workspaces; | |||
| AddressPtrList kernel_outputs; | |||
| @@ -808,6 +859,19 @@ bool GPUKernelRuntime::AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &k | |||
      // The graph may be "nop node + depend + node"; the input of node is the depend, so this case needs to skip the nop node.
      device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, true);
    }
    // For the aggregate node, redirect its aggregated input to the skip node's input address, i.e. the shared in-place output buffer.
| if (AnfAlgo::IsInplaceNode(kernel, "aggregate")) { | |||
| auto primitive = AnfAlgo::GetCNodePrimitive(kernel); | |||
| auto input_index = GetValue<uint32_t>(primitive->GetAttr("aggregate_input_index")); | |||
| if (i == input_index) { | |||
| auto skip_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(kernel), input_index); | |||
| device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(skip_node, 0, false); | |||
| MS_LOG(INFO) << "[inplace optimizer] aggregate: " << kernel->DebugString() | |||
| << ", skip: " << skip_node->DebugString() << ", address: " << device_address->GetMutablePtr(); | |||
| } | |||
| } | |||
| MS_EXCEPTION_IF_NULL(device_address); | |||
| UpdateHostSwapInQueue(device_address, mock); | |||
| MS_EXCEPTION_IF_NULL(device_address->ptr_); | |||
| @@ -969,6 +1033,14 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel) | |||
| } | |||
| // Free the input of kernel by reference count. | |||
| for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { | |||
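    // Do not free the aggregated input by ref count: it aliases the shared in-place buffer whose lifetime is managed
    // by the inplace producers, not by the skipped TensorAdd.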
| if (AnfAlgo::IsInplaceNode(kernel, "aggregate")) { | |||
| auto primitive = AnfAlgo::GetCNodePrimitive(kernel); | |||
| auto index = GetValue<uint32_t>(primitive->GetAttr("aggregate_input_index")); | |||
| if (i == index) { | |||
| continue; | |||
| } | |||
| } | |||
| auto kernel_ref_count_ptr = mem_reuse_util_->GetKernelInputRef(cnode, i); | |||
| if (kernel_ref_count_ptr == nullptr) { | |||
| continue; | |||
| @@ -94,6 +94,7 @@ class GPUKernelRuntime : public KernelRuntime { | |||
| void UpdateHostSwapInQueue(const DeviceAddressPtr device_address, bool mock); | |||
| void UpdateHostSwapOutQueue(bool mock); | |||
| void ClearSwapInfo(bool mock); | |||
| void AllocInplaceNodeMemory(const session::KernelGraph *graph); | |||
| std::unordered_map<uint32_t, MemReuseUtilPtr> mem_reuse_util_map_; | |||
| std::unordered_map<uint32_t, MemSwapManagerPtr> mem_swap_map_; | |||
| std::unordered_map<uint32_t, bool> is_first_step_map_; | |||
| @@ -193,6 +193,7 @@ constexpr auto kEmbeddingLookupProxyOpName = "EmbeddingLookupProxy"; | |||
| constexpr auto kPaddingOpName = "Padding"; | |||
| constexpr auto kAvgPoolOpName = "AvgPool"; | |||
| constexpr auto kAvgPoolGradGpuOpName = "AvgPoolGradGpu"; | |||
| constexpr auto kmaxPoolGradOpName = "MaxPoolGrad"; | |||
| constexpr auto kTensorAddOpName = "TensorAdd"; | |||
| constexpr auto kCastOpName = "Cast"; | |||
| constexpr auto kGreaterEqualOpName = "GreaterEqual"; | |||
| @@ -0,0 +1,75 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore.common.parameter import Parameter | |||
| from mindspore.common.initializer import initializer | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops.operations import _grad_ops as G | |||
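# Two Conv2DBackpropInput results are added and fed into MaxPoolGrad. With the cudnn inplace fusion pass, the second
# conv should accumulate into the first conv's output buffer and the TensorAdd should be skipped, so the graph-mode
# (fused) output is compared against the PyNative-mode (unfused) output to verify numerical equivalence.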
class Conv2dBpropInputInplace(nn.Cell):
    def __init__(self, w1, w2):
        super(Conv2dBpropInputInplace, self).__init__()
        self.conv2d_1 = P.Conv2DBackpropInput(out_channel=256, kernel_size=1)
        self.w1 = Parameter(initializer(w1, w1.shape), name='w1')
        self.conv2d_2 = P.Conv2DBackpropInput(out_channel=256, kernel_size=1)
        self.w2 = Parameter(initializer(w2, w2.shape), name='w2')
        self.add = P.TensorAdd()
        self.maxpool = P.MaxPool(ksize=3, strides=2, padding='SAME')
        self.maxpool_grad = G.MaxPoolGrad(ksize=3, strides=2, padding='SAME')
        self.shape = (32, 64, 56, 56)

    def construct(self, x1, x2, x3):
        dx1 = self.conv2d_1(x1, self.w1, self.shape)
        dx2 = self.conv2d_2(x2, self.w2, self.shape)
        dx = self.add(dx1, dx2)
        y = self.maxpool(x3)
        y = self.maxpool_grad(x3, y, dx)
        return y


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_inplace_fusion1():
    np.random.seed(42)
    w1_np = np.random.randn(64, 64, 1, 1)
    w2_np = np.random.randn(256, 64, 1, 1)
    x1_np = np.random.randn(32, 64, 56, 56)
    x2_np = np.random.randn(32, 256, 56, 56)
    x3_np = np.random.randn(32, 64, 112, 112)
    w1 = Tensor(w1_np.astype(np.float32))
    w2 = Tensor(w2_np.astype(np.float32))
    x1 = Tensor(x1_np.astype(np.float32))
    x2 = Tensor(x2_np.astype(np.float32))
    x3 = Tensor(x3_np.astype(np.float32))
    net = Conv2dBpropInputInplace(w1, w2)
    context.set_context(device_target='GPU', mode=context.GRAPH_MODE)
    fusion_output = net(x1, x2, x3)
    context.set_context(device_target='GPU', mode=context.PYNATIVE_MODE)
    no_fusion_output = net(x1, x2, x3)
    assert np.allclose(fusion_output.asnumpy(), no_fusion_output.asnumpy(), atol=2e-5)