Browse Source

cudnn inplace optimizer

tags/v1.1.0
wilfChen 5 years ago
parent
commit
b420b6cda7
14 changed files with 447 additions and 9 deletions
  1. +9
    -0
      mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h
  2. +6
    -4
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h
  3. +5
    -4
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_grad_ex_gpu_kernel.h
  4. +195
    -0
      mindspore/ccsrc/backend/optimizer/gpu/cudnn_inplace_fusion.cc
  5. +32
    -0
      mindspore/ccsrc/backend/optimizer/gpu/cudnn_inplace_fusion.h
  6. +32
    -1
      mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.cc
  7. +2
    -0
      mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.h
  8. +15
    -0
      mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
  9. +1
    -0
      mindspore/ccsrc/backend/session/anf_runtime_algorithm.h
  10. +1
    -0
      mindspore/ccsrc/backend/session/gpu_session.cc
  11. +72
    -0
      mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc
  12. +1
    -0
      mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h
  13. +1
    -0
      mindspore/ccsrc/utils/utils.h
  14. +75
    -0
      tests/st/ops/gpu/test_cudnn_inplace_fusion.py

+ 9
- 0
mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h View File

@@ -77,6 +77,15 @@ class GpuKernel : public KernelMod {
}
return GetValue<T>(attr);
}
template <typename T>
inline T GetAttrWithDefault(const CNodePtr &kernel_node, const std::string &key, const T &value) const {
const PrimitivePtr &prim = AnfAlgo::GetCNodePrimitive(kernel_node);
const ValuePtr &attr = prim->GetAttr(key);
if (attr == nullptr) {
return value;
}
return GetValue<T>(attr);
}
// expand Nd Shape to 4d (N in [0,4])
void ShapeNdTo4d(const std::vector<size_t> &src, std::vector<int> *dst) {
if (src.size() > 4) {


+ 6
- 4
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h View File

@@ -54,7 +54,8 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
output_size_(0),
padded_size_(0),
workspace_size_(0),
use_pad_(true) {}
use_pad_(true),
beta_(0) {}
~ConvGradInputGpuBkwKernel() override { DestroyResource(); }

const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -75,13 +76,12 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
}

const float alpha = 1;
const float beta = 0;
if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) {
T *padded = GetDeviceAddress<T>(workspace, 1);

CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnConvolutionBackwardData(cudnn_handle_, &alpha, w_desc_, w, dy_desc_, dy, conv_desc_, algo_, work_space,
workspace_size_, &beta, padded_descriptor_, padded),
workspace_size_, &beta_, padded_descriptor_, padded),
"ConvolutionBackwardData failed");
if (data_format_ == "NHWC") {
CalPadGradNHWC(output_size_ / sizeof(T), padded, n_, old_height_, old_width_, c_, old_height_ + pad_height_,
@@ -93,7 +93,7 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
} else {
CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnConvolutionBackwardData(cudnn_handle_, &alpha, w_desc_, w, dy_desc_, dy, conv_desc_, algo_, work_space,
workspace_size_, &beta, dx_desc_, dx),
workspace_size_, &beta_, dx_desc_, dx),
"ConvolutionBackwardData failed");
}
return true;
@@ -188,6 +188,7 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
"cudnnSetConvolutionMathType failed.")
}
SelectAlgorithm(dx_desc_real);
beta_ = GetAttrWithDefault(kernel_node, "inplace_algo", std::string("cover")) == "cover" ? 0 : 1;
InitSizeLists();
return true;
}
@@ -349,6 +350,7 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
size_t padded_size_;
size_t workspace_size_;
bool use_pad_;
float beta_;
};
} // namespace kernel
} // namespace mindspore


+ 5
- 4
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_grad_ex_gpu_kernel.h View File

@@ -46,7 +46,8 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel {
scale_bias_diff_desc_(nullptr),
activation_desc_(nullptr),
handle_(nullptr),
cudnn_data_type_(CUDNN_DATA_FLOAT) {}
cudnn_data_type_(CUDNN_DATA_FLOAT),
beta_data_diff_(0) {}
~FusedBatchNormGradExGpuKernel() override { DestroyResource(); }

const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -88,12 +89,10 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel {
}

const float alpha_data_diff = 1;
const float beta_data_diff = 0;
const float alpha_param_diff = 1;
const float beta_param_diff = 0;

CHECK_CUDNN_RET_WITH_EXCEPT(cudnnBatchNormalizationBackwardEx(
handle_, mode_, bn_ops_, &alpha_data_diff, &beta_data_diff, &alpha_param_diff,
handle_, mode_, bn_ops_, &alpha_data_diff, &beta_data_diff_, &alpha_param_diff,
&beta_param_diff, x_desc_, x, y_desc_, y, dy_desc_, dy, dz_desc_, dz, dx_desc_, dx,
scale_bias_diff_desc_, scale, bias, dscale, dbias, epsilon_, save_mean, save_variance,
activation_desc_, workspace_addr, workspace_size_, reserve_addr, reserve_size_),
@@ -141,6 +140,7 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel {
return true;
}
std::string format = AnfAlgo::GetInputFormat(kernel_node, 0);
beta_data_diff_ = GetAttrWithDefault(kernel_node, "inplace_algo", std::string("cover")) == "cover" ? 0 : 1;
SetTensorDescriptor(format, shape);
InitSizeLists();
return true;
@@ -285,6 +285,7 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel {

cudnnHandle_t handle_;
cudnnDataType_t cudnn_data_type_;
float beta_data_diff_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;


+ 195
- 0
mindspore/ccsrc/backend/optimizer/gpu/cudnn_inplace_fusion.cc View File

@@ -0,0 +1,195 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "backend/optimizer/gpu/cudnn_inplace_fusion.h"

#include <memory>
#include <vector>
#include <set>
#include <map>
#include <algorithm>
#include <utility>
#include <string>

#include "backend/session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
#include "utils/contract.h"
#include "backend/optimizer/common/helper.h"
#include "runtime/device/gpu/kernel_info_setter.h"

namespace mindspore {
namespace opt {
namespace {
struct AnfNodeIndex {
AnfNodeIndex() : node(nullptr), index(0) {}
AnfNodeIndex(const AnfNodePtr &n, const int i) : node(n), index(i) {}
AnfNodePtr node;
uint32_t index;
};

// opname, output idx
std::map<string, uint32_t> kInplaceOpNames = {{kConv2DBackpropInputOpName, 0},
{kFusedBatchNormGradExWithAddAndActivation, 3}};

std::set<string> kSkipOpNames = {
kTensorAddOpName,
};

// opname, input idx
std::map<string, uint32_t> kAggregatesOpNames = {
{kConv2DBackpropInputOpName, 0}, {kmaxPoolGradOpName, 2}, {kFusedBatchNormGradExWithAddAndActivation, 0}};

template <typename T>
void SetPrimAttr(AnfNodePtr inplace_node, const string &key, const T &value) {
auto primitive = AnfAlgo::GetCNodePrimitive(inplace_node);
MS_EXCEPTION_IF_NULL(primitive);
primitive->AddAttr(key, MakeValue(value));
}

void SetNodeAttr(AnfNodeIndex aggregate_node, AnfNodePtr skip_node, std::vector<AnfNodeIndex> *inplace_node) {
SetPrimAttr(aggregate_node.node, "aggregate", true);
SetPrimAttr(aggregate_node.node, "aggregate_input_index", aggregate_node.index);
SetPrimAttr(skip_node, "skip", true);

static uint32_t group = 0;
for (size_t i = 0; i < inplace_node->size(); i++) {
auto algo = (i == 0) ? "cover" : "accumulation";
SetPrimAttr((*inplace_node)[i].node, "inplace_algo", algo);
SetPrimAttr((*inplace_node)[i].node, "inplace_group", group);
SetPrimAttr((*inplace_node)[i].node, "inplace_output_index", (*inplace_node)[i].index);
}
group++;
}

void InsertControlDependToGraph(const FuncGraphPtr &graph, const std::vector<AnfNodeIndex> &inplace_nodes) {
std::vector<AnfNodePtr> inputs1 = {NewValueNode(std::make_shared<Primitive>(prim::kPrimControlDepend->name())),
inplace_nodes[0].node, inplace_nodes[1].node};
auto control_depend_node = graph->NewCNode(inputs1);

auto return_node = graph->get_return();
MS_EXCEPTION_IF_NULL(return_node);
// mount the `depend` before make_tuple, otherwise the output of graph will be `(tensor, )` rather than `tensor`
auto return_input = return_node->input(kFirstDataInputIndex)->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(return_input);
std::vector<AnfNodePtr> inputs2 = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
return_input->input(kFirstDataInputIndex), control_depend_node};
auto depend_node = graph->NewCNode(inputs2);
return_node->set_input(kFirstDataInputIndex, depend_node);
}

bool PatternMatch(const FuncGraphPtr &graph, const AnfNodePtr &node, AnfNodeIndex *aggregate, AnfNodePtr *skip_node,
std::vector<AnfNodeIndex> *inplace) {
MS_EXCEPTION_IF_NULL(skip_node);
MS_EXCEPTION_IF_NULL(aggregate);
if (!node->isa<CNode>()) {
return false;
}
auto aggregate_iter = kAggregatesOpNames.find(AnfAlgo::GetCNodeName(node));
if (aggregate_iter == kAggregatesOpNames.end()) {
return false;
}
aggregate->node = node;
aggregate->index = aggregate_iter->second;

*skip_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), aggregate_iter->second);
if (*skip_node == nullptr || !(*skip_node)->isa<CNode>() ||
kSkipOpNames.count(AnfAlgo::GetCNodeName(*skip_node)) == 0 ||
GetRealNodeUsedList(graph, *skip_node)->size() >= 2) {
return false;
}

auto cnode = (*skip_node)->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode); i++) {
auto inplace_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(*skip_node), i);
if (!inplace_node->isa<CNode>()) {
return false;
}
// Check Inplace nodes have no user except TensorAdd nodes
if (GetRealNodeUsedList(graph, inplace_node)->size() >= 2) {
return false;
}

// skip TupleGetItem node
if (AnfAlgo::GetCNodeName(inplace_node) == prim::kPrimTupleGetItem->name()) {
inplace_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(inplace_node), 0);
}

auto inplace_iter = kInplaceOpNames.find(AnfAlgo::GetCNodeName(inplace_node));
if (inplace_iter == kInplaceOpNames.end()) {
return false;
}

inplace->push_back(AnfNodeIndex(inplace_node, inplace_iter->second));
}

return true;
}

std::map<AnfNodePtr, int> TopoIndex(const std::vector<AnfNodePtr> &node_list) {
std::map<AnfNodePtr, int> topo_index;
for (size_t i = 0; i < node_list.size(); i++) {
topo_index.insert(make_pair(node_list[i], i));
}
return topo_index;
}
} // namespace

bool CudnnInplaceAggregate::Run(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> node_list = TopoSort(graph->get_return());
auto topo_index = TopoIndex(node_list);

for (auto node : node_list) {
AnfNodeIndex aggregate_node;
AnfNodePtr skip_node;
std::vector<AnfNodeIndex> inplace_node;
// 1. Pattern Match.
if (!PatternMatch(graph, node, &aggregate_node, &skip_node, &inplace_node)) {
continue;
}

// 2. Keep the original topological order in case the dependence between inplace nodes
std::sort(inplace_node.begin(), inplace_node.end(), [&topo_index](const AnfNodeIndex &n1, const AnfNodeIndex &n2) {
auto iter1 = topo_index.find(n1.node);
auto iter2 = topo_index.find(n2.node);
if (iter1 == topo_index.end() || iter2 == topo_index.end()) {
MS_LOG(EXCEPTION) << ": Node not existed in topo order. node1: " << n1.node->DebugString()
<< ", node2: " << n2.node->DebugString();
}

if (iter1->second < iter2->second) {
return true;
}
return false;
});
MS_LOG(INFO) << "[inplace optimizer] aggregate node: " << aggregate_node.index << ", "
<< aggregate_node.node->DebugString() << "; skip node: " << skip_node->DebugString() << std::endl
<< "; inplace node 0: " << inplace_node[0].index << ", " << inplace_node[0].node->DebugString()
<< std::endl
<< "; inplace node 1: " << inplace_node[1].index << ", " << inplace_node[1].node->DebugString()
<< std::endl;
// 2. Set Node attr
SetNodeAttr(aggregate_node, skip_node, &inplace_node);
// 3. Set dependence for inplace nodes
InsertControlDependToGraph(graph, inplace_node);
}

return true;
}
} // namespace opt
} // namespace mindspore

+ 32
- 0
mindspore/ccsrc/backend/optimizer/gpu/cudnn_inplace_fusion.h View File

@@ -0,0 +1,32 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_CUDNN_INPLACE_AGGREGATE_FUSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_CUDNN_INPLACE_AGGREGATE_FUSION_H_

#include <memory>
#include "backend/optimizer/common/optimizer.h"

namespace mindspore {
namespace opt {
class CudnnInplaceAggregate : public Pass {
public:
CudnnInplaceAggregate() : Pass("cudnn_inplace_aggregate") {}
~CudnnInplaceAggregate() override = default;
bool Run(const FuncGraphPtr &g) override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_CUDNN_INPLACE_AGGREGATE_FUSION_H_

+ 32
- 1
mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.cc View File

@@ -18,6 +18,7 @@
#include <algorithm>
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/common/helper.h"
#include "runtime/device/kernel_runtime_manager.h"

namespace mindspore {
namespace device {
@@ -191,6 +192,33 @@ bool MemSwapManager::IsCommunicationRelevantOp(const AnfNodePtr &kernel) const {
return adjacent_with_communication_op;
}

bool MemSwapManager::IsInplaceRelevantOp(const TensorInfo &tensor) {
MS_EXCEPTION_IF_NULL(tensor.kernel_);
if (AnfAlgo::IsInplaceNode(tensor.kernel_, "inplace_algo") || AnfAlgo::IsInplaceNode(tensor.kernel_, "skip")) {
return true;
}

MS_EXCEPTION_IF_NULL(kernel_graph_);
const auto &graph_manager = kernel_graph_->manager();
MS_EXCEPTION_IF_NULL(graph_manager);
NodeUsersMap &user_map = graph_manager->node_users();

auto users = user_map.find(tensor.kernel_);
for (const auto &user : users->second) {
if (!AnfAlgo::IsInplaceNode(user.first, "aggregate")) {
continue;
}

auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(user.first, user.second);
if (tensor.output_idx_ == kernel_with_index.second) {
MS_LOG(INFO) << " [inplace optimizer] tensor: " << tensor.kernel_->DebugString()
<< "output idx: " << tensor.output_idx_ << " used by aggregate node: " << user.first->DebugString();
return true;
}
}
return false;
}

void MemSwapManager::SaveUserKernelTopoOrder() {
MS_EXCEPTION_IF_NULL(kernel_graph_);
const auto &graph_manager = kernel_graph_->manager();
@@ -231,7 +259,10 @@ void MemSwapManager::AddSwapInfo() {
}

const AnfNodePtr &kernel = tensor.kernel_;
if (IsCommunicationRelevantOp(kernel)) {
bool filter = IsCommunicationRelevantOp(kernel) || IsInplaceRelevantOp(tensor);
if (filter) {
MS_LOG(INFO) << " [inplace optimizer] ignore swap tensor: " << kernel->DebugString() << ", index"
<< tensor.output_idx_;
continue;
}



+ 2
- 0
mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.h View File

@@ -136,6 +136,8 @@ class MemSwapManager {

bool IsCommunicationRelevantOp(const AnfNodePtr &kernel) const;

bool IsInplaceRelevantOp(const TensorInfo &tensor);

std::vector<CNodePtr> execution_order_;
std::vector<TensorInfo> ordered_tensors_;
std::unordered_map<void *, KernelExecutionInfo> kernel_execution_info_;


+ 15
- 0
mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc View File

@@ -1031,6 +1031,21 @@ void AnfRuntimeAlgorithm::SetNodeInput(const CNodePtr &node, const AnfNodePtr &i
node->set_input(index + 1, input_node);
}

bool AnfRuntimeAlgorithm::IsInplaceNode(const mindspore::AnfNodePtr &kernel, const string &type) {
MS_EXCEPTION_IF_NULL(kernel);
auto primitive = AnfAlgo::GetCNodePrimitive(kernel);
if (!primitive) {
return false;
}

auto inplace_attr = primitive->GetAttr(type);
if (inplace_attr == nullptr) {
return false;
}

return true;
}

bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {


+ 1
- 0
mindspore/ccsrc/backend/session/anf_runtime_algorithm.h View File

@@ -206,6 +206,7 @@ class AnfRuntimeAlgorithm {
// get real input index for some tbe ops which input order is different between me and tbe impl
static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index);
static bool IsCommunicationOp(const AnfNodePtr &node);
static bool IsInplaceNode(const AnfNodePtr &node, const string &type);
static bool IsGetNext(const NotNull<AnfNodePtr> &node);
static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node);
static std::vector<KernelGraphPtr> GetCallSwitchKernelGraph(const CNodePtr &cnode);


+ 1
- 0
mindspore/ccsrc/backend/session/gpu_session.cc View File

@@ -39,6 +39,7 @@
#include "backend/optimizer/gpu/insert_format_transform_op.h"
#include "backend/optimizer/gpu/remove_format_transform_pair.h"
#include "backend/optimizer/gpu/remove_redundant_format_transform.h"
#include "backend/optimizer/gpu/cudnn_inplace_fusion.h"
#include "backend/optimizer/graph_kernel/graph_kernel_splitter.h"
#include "backend/optimizer/graph_kernel/graph_kernel_expander.h"
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"


+ 72
- 0
mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc View File

@@ -15,6 +15,7 @@
*/
#include "runtime/device/gpu/gpu_kernel_runtime.h"
#include <algorithm>
#include <map>
#include "pybind11/pybind11.h"
#include "runtime/device/gpu/gpu_device_address.h"
#include "runtime/device/gpu/cuda_driver.h"
@@ -279,6 +280,51 @@ void GPUKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std::v
ClearOutputAddress(inputs, value_nodes, execution_order);
}

void GPUKernelRuntime::AllocInplaceNodeMemory(const session::KernelGraph *graph) {
std::map<uint32_t, std::vector<CNodePtr>> inplace_groups;
auto kernel_cnodes = graph->execution_order();
for (auto &kernel : kernel_cnodes) {
if (!AnfAlgo::IsInplaceNode(kernel, "inplace_algo")) {
continue;
}
auto primitive = AnfAlgo::GetCNodePrimitive(kernel);
auto group_attr = primitive->GetAttr("inplace_group");
MS_EXCEPTION_IF_NULL(group_attr);
auto group_id = GetValue<uint32_t>(group_attr);
inplace_groups[group_id].push_back(kernel);
}

for (auto &group : inplace_groups) {
auto &item = group.second;
// in-place compute when group size >= 2.
if (item.size() < 2) {
continue;
}

auto primitive = AnfAlgo::GetCNodePrimitive(item[0]);
auto output_index = GetValue<uint32_t>(primitive->GetAttr("inplace_output_index"));
auto device_address = AnfAlgo::GetMutableOutputAddr(item[0], output_index, false);
if (device_address->GetPtr() != nullptr) {
continue;
}

auto kernel_mod = AnfAlgo::GetKernelMod(item[0]);
auto output_size = kernel_mod->GetOutputSizeList();
auto ret = mem_manager_->MallocMemFromMemPool(device_address, output_size[output_index]);
if (!ret) {
MS_LOG(EXCEPTION) << "Cannot alloc address, tensor size is: " << output_size[output_index];
}

for (auto &node : item) {
auto prim = AnfAlgo::GetCNodePrimitive(node);
auto index = GetValue<uint32_t>(prim->GetAttr("inplace_output_index"));
AnfAlgo::SetOutputAddr(device_address, index, node.get());
MS_LOG(INFO) << "[inplace optimizer] group id: " << group.first << ", node: " << node->DebugString()
<< ", output_index: " << index;
}
}
}

void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
@@ -545,6 +591,7 @@ bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph, De
mem_reuse_util_->ResetDynamicUsedRefCount();
// The inputs and outputs memory of communication kernel need be continuous, so separate processing.
AllocCommunicationOpDynamicRes(graph);
AllocInplaceNodeMemory(graph);

#ifdef ENABLE_DEBUGGER
debugger_ = debugger;
@@ -562,6 +609,10 @@ bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph, De
for (const auto &kernel : kernels) {
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);
if (AnfAlgo::IsInplaceNode(kernel, "skip")) {
MS_LOG(INFO) << "[inplace optimizer] skip node: " << kernel->DebugString();
continue;
}
AddressPtrList kernel_inputs;
AddressPtrList kernel_workspaces;
AddressPtrList kernel_outputs;
@@ -808,6 +859,19 @@ bool GPUKernelRuntime::AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &k
// Graph may be "nop node + depend + node", the input of node is the depend, so this case need skip nop node.
device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, true);
}

// Get in-place output_address
if (AnfAlgo::IsInplaceNode(kernel, "aggregate")) {
auto primitive = AnfAlgo::GetCNodePrimitive(kernel);
auto input_index = GetValue<uint32_t>(primitive->GetAttr("aggregate_input_index"));
if (i == input_index) {
auto skip_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(kernel), input_index);
device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(skip_node, 0, false);
MS_LOG(INFO) << "[inplace optimizer] aggregate: " << kernel->DebugString()
<< ", skip: " << skip_node->DebugString() << ", address: " << device_address->GetMutablePtr();
}
}

MS_EXCEPTION_IF_NULL(device_address);
UpdateHostSwapInQueue(device_address, mock);
MS_EXCEPTION_IF_NULL(device_address->ptr_);
@@ -969,6 +1033,14 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel)
}
// Free the input of kernel by reference count.
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
if (AnfAlgo::IsInplaceNode(kernel, "aggregate")) {
auto primitive = AnfAlgo::GetCNodePrimitive(kernel);
auto index = GetValue<uint32_t>(primitive->GetAttr("aggregate_input_index"));
if (i == index) {
continue;
}
}

auto kernel_ref_count_ptr = mem_reuse_util_->GetKernelInputRef(cnode, i);
if (kernel_ref_count_ptr == nullptr) {
continue;


+ 1
- 0
mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h View File

@@ -94,6 +94,7 @@ class GPUKernelRuntime : public KernelRuntime {
void UpdateHostSwapInQueue(const DeviceAddressPtr device_address, bool mock);
void UpdateHostSwapOutQueue(bool mock);
void ClearSwapInfo(bool mock);
void AllocInplaceNodeMemory(const session::KernelGraph *graph);
std::unordered_map<uint32_t, MemReuseUtilPtr> mem_reuse_util_map_;
std::unordered_map<uint32_t, MemSwapManagerPtr> mem_swap_map_;
std::unordered_map<uint32_t, bool> is_first_step_map_;


+ 1
- 0
mindspore/ccsrc/utils/utils.h View File

@@ -193,6 +193,7 @@ constexpr auto kEmbeddingLookupProxyOpName = "EmbeddingLookupProxy";
constexpr auto kPaddingOpName = "Padding";
constexpr auto kAvgPoolOpName = "AvgPool";
constexpr auto kAvgPoolGradGpuOpName = "AvgPoolGradGpu";
constexpr auto kmaxPoolGradOpName = "MaxPoolGrad";
constexpr auto kTensorAddOpName = "TensorAdd";
constexpr auto kCastOpName = "Cast";
constexpr auto kGreaterEqualOpName = "GreaterEqual";


+ 75
- 0
tests/st/ops/gpu/test_cudnn_inplace_fusion.py View File

@@ -0,0 +1,75 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G


class Conv2dBpropInputInplace(nn.Cell):
def __init__(self, w1, w2):
super(Conv2dBpropInputInplace, self).__init__()
self.conv2d_1 = P.Conv2DBackpropInput(out_channel=256, kernel_size=1)
self.w1 = Parameter(initializer(w1, w1.shape), name='w1')
self.conv2d_2 = P.Conv2DBackpropInput(out_channel=256, kernel_size=1)
self.w2 = Parameter(initializer(w2, w2.shape), name='w2')
self.add = P.TensorAdd()
self.maxpool = P.MaxPool(ksize=3, strides=2, padding='SAME')
self.maxpool_grad = G.MaxPoolGrad(ksize=3, strides=2, padding='SAME')
self.shape = (32, 64, 56, 56)

def construct(self, x1, x2, x3):
dx1 = self.conv2d_1(x1, self.w1, self.shape)
dx2 = self.conv2d_2(x2, self.w2, self.shape)

dx = self.add(dx1, dx2)
y = self.maxpool(x3)
y = self.maxpool_grad(x3, y, dx)
return y


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_inplace_fusion1():

np.random.seed(42)
w1_np = np.random.randn(64, 64, 1, 1)
w2_np = np.random.randn(256, 64, 1, 1)
x1_np = np.random.randn(32, 64, 56, 56)
x2_np = np.random.randn(32, 256, 56, 56)
x3_np = np.random.randn(32, 64, 112, 112)

w1 = Tensor(w1_np.astype(np.float32))
w2 = Tensor(w2_np.astype(np.float32))
x1 = Tensor(x1_np.astype(np.float32))
x2 = Tensor(x2_np.astype(np.float32))
x3 = Tensor(x3_np.astype(np.float32))

net = Conv2dBpropInputInplace(w1, w2)
context.set_context(device_target='GPU', mode=context.GRAPH_MODE)
fusion_output = net(x1, x2, x3)

context.set_context(device_target='GPU', mode=context.PYNATIVE_MODE)
no_fusion_output = net(x1, x2, x3)

assert np.allclose(fusion_output.asnumpy(), no_fusion_output.asnumpy(), atol=2e-5)

Loading…
Cancel
Save